
Commit 340b222

Merge branch 'upstream' into concedo_experimental
# Conflicts:
#	.devops/intel.Dockerfile
#	.github/workflows/build-android.yml
#	.github/workflows/build.yml
#	.github/workflows/release.yml
#	.gitignore
#	docs/backend/SYCL.md
#	docs/backend/snapdragon/README.md
#	examples/model-conversion/scripts/causal/convert-model.sh
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-hexagon/ggml-hexagon.cpp
#	ggml/src/ggml-hexagon/htp/CMakeLists.txt
#	ggml/src/ggml-hexagon/htp/hex-utils.h
#	ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
#	ggml/src/ggml-hexagon/htp/htp-ctx.h
#	ggml/src/ggml-hexagon/htp/htp-ops.h
#	ggml/src/ggml-hexagon/htp/htp_iface.idl
#	ggml/src/ggml-hexagon/htp/hvx-base.h
#	ggml/src/ggml-hexagon/htp/main.c
#	ggml/src/ggml-hexagon/htp/matmul-ops.c
#	ggml/src/ggml-hexagon/libggml-htp.inf
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/mmvq.cpp
#	ggml/src/ggml-sycl/mmvq.hpp
#	ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
#	ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
#	ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
#	scripts/server-test-structured.py
#	scripts/snapdragon/adb/run-bench.sh
#	scripts/snapdragon/adb/run-cli.sh
#	scripts/snapdragon/adb/run-completion.sh
#	scripts/snapdragon/adb/run-mtmd.sh
#	scripts/snapdragon/adb/run-tool.sh
#	scripts/snapdragon/qdc/requirements.txt
#	scripts/snapdragon/windows/run-bench.ps1
#	scripts/snapdragon/windows/run-cli.ps1
#	scripts/snapdragon/windows/run-completion.ps1
#	scripts/snapdragon/windows/run-mtmd.ps1
#	scripts/snapdragon/windows/run-tool.ps1
#	tests/test-backend-ops.cpp
#	tools/cli/cli.cpp
#	ty.toml
2 parents: 4090400 + 0adede8

20 files changed: 260 additions & 98 deletions

common/chat.cpp

Lines changed: 21 additions & 5 deletions
@@ -558,6 +558,26 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp
     return tmpls->has_explicit_template;
 }
 
+// LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list
+// and <|tool_call_start|>[...]<|tool_call_end|> around each tool call
+static bool is_lfm2_template(const std::string & src) {
+    return src.find("<|tool_list_start|>") != std::string::npos &&
+           src.find("<|tool_list_end|>") != std::string::npos;
+}
+
+common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates) {
+    common_chat_prompt_preset asr_preset;
+    asr_preset.system = "";
+    asr_preset.user = "Transcribe audio to text";
+
+    if (chat_templates && chat_templates->template_default && is_lfm2_template(chat_templates->template_default->source())) {
+        asr_preset.system = "Perform ASR.";
+        asr_preset.user = "";
+    }
+
+    return asr_preset;
+}
+
 std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
     if (!variant.empty()) {
         if (variant == "tool_use") {
@@ -2067,10 +2087,7 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
         return common_chat_params_init_kimi_k2(tmpl, params);
     }
 
-    // LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list
-    // and <|tool_call_start|>[...]<|tool_call_end|> around each tool call
-    if (src.find("<|tool_list_start|>") != std::string::npos &&
-        src.find("<|tool_list_end|>") != std::string::npos) {
+    if (is_lfm2_template(src)) {
         LOG_DBG("Using specialized template: LFM2\n");
         return common_chat_params_init_lfm2(tmpl, params);
     }
@@ -2379,4 +2396,3 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
     GGML_ASSERT(chat_templates->template_default != nullptr);
     return chat_templates->template_default->caps.to_map();
 }
-
common/chat.h

Lines changed: 8 additions & 0 deletions
@@ -274,3 +274,11 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
     const common_chat_template & tmpl,
     const std::string & src,
     autoparser::generation_params & params);
+
+// specialized per-task preset
+struct common_chat_prompt_preset {
+    std::string system;
+    std::string user;
+};
+
+common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
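For orientation, a hedged usage sketch of the new preset (everything here except common_chat_prompt_preset and common_chat_get_asr_prompt is an assumption about the caller, not part of this commit): the helper yields a generic user-turn prompt, or an LFM2-specific system-turn instruction, and a caller would turn the non-empty fields into chat messages.

// Hypothetical caller sketch: chat_templates is assumed to be loaded elsewhere;
// only common_chat_get_asr_prompt / common_chat_prompt_preset come from this diff.
const common_chat_prompt_preset preset = common_chat_get_asr_prompt(chat_templates);

common_chat_msg system_msg;
system_msg.role    = "system";
system_msg.content = preset.system;   // "Perform ASR." only for LFM2 templates, otherwise empty

common_chat_msg user_msg;
user_msg.role    = "user";
user_msg.content = preset.user;       // "Transcribe audio to text" for non-LFM2 templates, otherwise empty

// A real caller would skip whichever of the two turns is empty before templating.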

common/common.h

Lines changed: 5 additions & 0 deletions
@@ -747,6 +747,11 @@ inline bool string_starts_with(std::string_view str, std::string_view prefix) {
            str.compare(0, prefix.size(), prefix) == 0;
 }
 
+// remove when moving to c++20
+inline bool string_starts_with(std::string_view str, char prefix) {
+    return !str.empty() && str.front() == prefix;
+}
+
 // remove when moving to c++20
 inline bool string_ends_with(std::string_view str, std::string_view suffix) {
     return str.size() >= suffix.size() &&
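A small illustration of where the new char overload helps (the variables below are placeholders): it avoids building a one-character std::string_view for a single-character prefix and, unlike indexing str[0], is well-defined on empty strings.

// Illustrative only; string_starts_with is the helper from common.h above.
std::string_view line = "# a comment line";
std::string_view arg  = "--verbose";

const bool is_comment  = string_starts_with(line, '#');   // new char overload, safe if line is empty
const bool is_long_opt = string_starts_with(arg, "--");    // existing string_view overload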

common/jinja/caps.cpp

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-#include "log.h"
 #include "value.h"
 #include "runtime.h"
 #include "caps.h"

common/jinja/runtime.h

Lines changed: 11 additions & 5 deletions
@@ -106,10 +106,16 @@ struct statement {
     size_t pos; // position in source, for debugging
     virtual ~statement() = default;
     virtual std::string type() const { return "Statement"; }
+
     // execute_impl must be overridden by derived classes
-    virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); }
+    virtual value execute_impl(context &) { throw_exec_error(); }
     // execute is the public method to execute a statement with error handling
     value execute(context &);
+
+private:
+    [[noreturn]] void throw_exec_error() const {
+        throw std::runtime_error("cannot exec " + type());
+    }
 };
 
 // Type Checking Utilities
@@ -143,7 +149,7 @@ struct program : public statement {
     program() = default;
     explicit program(statements && body) : body(std::move(body)) {}
     std::string type() const override { return "Program"; }
-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
         throw std::runtime_error("Cannot execute program directly, use jinja::runtime instead");
     }
 };
@@ -195,7 +201,7 @@ struct break_statement : public statement {
         }
     };
 
-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
         throw break_statement::signal();
     }
 };
@@ -209,7 +215,7 @@ struct continue_statement : public statement {
         }
     };
 
-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
         throw continue_statement::signal();
     }
 };
@@ -509,7 +515,7 @@ struct slice_expression : public expression {
         chk_type<expression>(this->step_expr);
     }
     std::string type() const override { return "SliceExpression"; }
-    value execute_impl(context &) override {
+    [[noreturn]] value execute_impl(context &) override {
         throw std::runtime_error("must be handled by MemberExpression");
     }
 };
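The intent of the [[noreturn]] changes above: execute_impl overrides that only ever throw can be annotated so the compiler knows control never returns, and the base-class error message now lives in one private helper. A minimal standalone sketch of the same pattern, using illustrative names rather than the jinja types:

#include <stdexcept>
#include <string>

struct node {
    virtual ~node() = default;
    virtual std::string type() const { return "Node"; }

    // Default behaviour: evaluating a bare node is an error; no dummy return is
    // needed because the helper below is declared [[noreturn]].
    virtual int eval() { throw_eval_error(); }

private:
    [[noreturn]] void throw_eval_error() const {
        throw std::runtime_error("cannot eval " + type());
    }
};

// An override that never produces a value can itself carry [[noreturn]];
// the attribute does not affect overriding or virtual dispatch.
struct placeholder : node {
    std::string type() const override { return "Placeholder"; }
    [[noreturn]] int eval() override { throw std::runtime_error("placeholder nodes are not evaluable"); }
};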

common/jinja/value.cpp

Lines changed: 13 additions & 9 deletions
@@ -590,6 +590,10 @@ static bool string_endswith(const std::string & str, const std::string & suffix)
     return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0;
 }
 
+[[noreturn]] static value string_join_not_implemented(const func_args &) {
+    throw not_implemented_exception("String join builtin not implemented");
+}
+
 const func_builtins & value_string_t::get_builtins() const {
     static const func_builtins builtins = {
         {"default", default_value},
@@ -851,9 +855,7 @@ const func_builtins & value_string_t::get_builtins() const {
             res->val_str.mark_input_based_on(val_input->as_string());
             return res;
         }},
-        {"join", [](const func_args &) -> value {
-            throw not_implemented_exception("String join builtin not implemented");
-        }},
+        {"join", string_join_not_implemented},
     };
     return builtins;
 }
@@ -884,6 +886,9 @@ const func_builtins & value_bool_t::get_builtins() const {
     return builtins;
 }
 
+[[noreturn]] static value array_unique_not_implemented(const func_args &) {
+    throw not_implemented_exception("Array unique builtin not implemented");
+}
 
 const func_builtins & value_array_t::get_builtins() const {
     static const func_builtins builtins = {
@@ -1084,13 +1089,14 @@ const func_builtins & value_array_t::get_builtins() const {
             std::reverse(arr.begin(), arr.end());
             return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
         }},
-        {"unique", [](const func_args &) -> value {
-            throw not_implemented_exception("Array unique builtin not implemented");
-        }},
+        {"unique", array_unique_not_implemented},
     };
     return builtins;
 }
 
+[[noreturn]] static value object_join_not_implemented(const func_args &) {
+    throw not_implemented_exception("object join not implemented");
+}
 
 const func_builtins & value_object_t::get_builtins() const {
     if (!has_builtins) {
@@ -1183,9 +1189,7 @@ const func_builtins & value_object_t::get_builtins() const {
             });
             return result;
         }},
-        {"join", [](const func_args &) -> value {
-            throw not_implemented_exception("object join not implemented");
-        }},
+        {"join", object_join_not_implemented},
     };
     return builtins;
 }
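The value.cpp change applies the same idea to the builtin tables: the throwing stubs move out of lambdas into named [[noreturn]] static functions, which keeps the table entries one-liners. A small sketch of the pattern with made-up stand-in types (fake_args/fake_builtins are illustrative, not the real jinja types):

#include <functional>
#include <map>
#include <stdexcept>
#include <string>

// Stand-ins for the jinja types, for illustration only.
struct fake_args {};
using fake_builtin  = std::function<int(const fake_args &)>;
using fake_builtins = std::map<std::string, fake_builtin>;

// A named [[noreturn]] stub can be stored in the table exactly like the old
// throwing lambda, but is reusable and its no-return contract is explicit.
[[noreturn]] static int join_not_implemented(const fake_args &) {
    throw std::runtime_error("join builtin not implemented");
}

static const fake_builtins builtins = {
    {"join", join_not_implemented},
};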

common/jinja/value.h

Lines changed: 21 additions & 18 deletions
@@ -129,27 +129,25 @@ struct value_t {
     // Note: only for debugging and error reporting purposes
     virtual std::string type() const { return ""; }
 
-    virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
-    virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); }
-    virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
-    virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
-    virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
-    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
+    virtual int64_t as_int() const { throw_type_error("is not an int value"); }
+    virtual double as_float() const { throw_type_error("is not a float value"); }
+    virtual string as_string() const { throw_type_error("is not a string value"); }
+    virtual bool as_bool() const { throw_type_error("is not a bool value"); }
+    virtual const std::vector<value> & as_array() const { throw_type_error("is not an array value"); }
+    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw_type_error("is not an object value"); }
+    virtual value invoke(const func_args &) const { throw_type_error("is not a function value"); }
     virtual bool is_none() const { return false; }
     virtual bool is_undefined() const { return false; }
-    virtual const func_builtins & get_builtins() const {
-        throw std::runtime_error("No builtins available for type " + type());
-    }
+    virtual const func_builtins & get_builtins() const { throw_type_error("has no builtins"); }
 
-    virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
-    virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
-    virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
+    virtual bool has_key(const value &) { throw_type_error("is not an object value"); }
+    virtual void insert(const value & /* key */, const value & /* val */) { throw_type_error("is not an object value"); }
+    virtual value & at(const value & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); }
+    virtual value & at(const value & /* key */) { throw_type_error("is not an object value"); }
+    virtual value & at(const std::string & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); }
+    virtual value & at(const std::string & /* key */) { throw_type_error("is not an object value"); }
+    virtual value & at(int64_t /* idx */, value & /* default_val */) { throw_type_error("is not an array value"); }
+    virtual value & at(int64_t /* idx */) { throw_type_error("is not an array value"); }
 
     virtual bool is_numeric() const { return false; }
     virtual bool is_hashable() const { return false; }
@@ -163,6 +161,11 @@ struct value_t {
     // Note: only for debugging purposes
     virtual std::string as_repr() const { return as_string().str(); }
 
+private:
+    [[noreturn]] void throw_type_error(const char* expected) const {
+        throw std::runtime_error(type() + " " + expected);
+    }
+
 protected:
     virtual bool equivalent(const value_t &) const = 0;
     virtual bool nonequal(const value_t & other) const { return !equivalent(other); }

convert_hf_to_gguf.py

Lines changed: 6 additions & 1 deletion
@@ -746,7 +746,12 @@ def prepare_tensors(self):
 
         if (not quant_algo or not quant_layers) and quant_config_file.is_file():
             with open(quant_config_file, "r", encoding="utf-8") as f:
-                quant_config = json.load(f).get("quantization") or {}
+                hf_quant_config = json.load(f)
+                quant_config = hf_quant_config.get("quantization") or {}
+                producer = hf_quant_config.get("producer") or {}
+                producer_name = (producer.get("name") or "").lower()
+                if quant_method is None:
+                    self.hparams.setdefault("quantization_config", {})["quant_method"] = producer_name
             quant_algo = quant_config.get("quant_algo", quant_algo)
             quant_layers = quant_config.get("quantized_layers", quant_layers) or {}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 30 additions & 0 deletions
@@ -3608,6 +3608,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
         return true;
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR
+            && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) {
+        const ggml_tensor * unary = cgraph->nodes[node_idx];
+        const ggml_tensor * sqr = cgraph->nodes[node_idx+1];
+
+        if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) {
+            return false;
+        }
+
+        if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
+            return false;
+        }
+
+        if (unary->type != sqr->type) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(unary->src[0])) {
+            return false;
+        }
+
+        return true;
+    }
+
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
             && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale = cgraph->nodes[node_idx];
@@ -4116,6 +4140,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                     continue;
                 }
 
+                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
+                    ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]);
+                    i++;
+                    continue;
+                }
+
                 if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
                     i += 2;
                     ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
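For context, the graph shape this new fusion targets is a RELU node whose result feeds a SQR node. A hedged sketch of how such a pair is typically built with the public ggml graph API (the helper function name is made up; ggml_relu and ggml_sqr are existing calls):

// Two graph nodes; with this commit the CUDA backend may execute them as one
// fused kernel when the checks above pass (F32/F16, contiguous input, matching types).
static struct ggml_tensor * build_relu_sqr(struct ggml_context * ctx, struct ggml_tensor * x) {
    return ggml_sqr(ctx, ggml_relu(ctx, x));
}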

ggml/src/ggml-cuda/unary.cu

Lines changed: 23 additions & 0 deletions
@@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) {
     return x * x;
 }
 
+static __device__ __forceinline__ float op_relu_sqr(float x) {
+    const float r = fmaxf(x, 0.0f);
+    return r * r;
+}
+
 static __device__ __forceinline__ float op_sqrt(float x) {
     return sqrtf(x);
 }
@@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary
         GGML_ABORT("Unsupported unary op for fused unary+mul");
     }
 }
+
+/* fused relu + sqr */
+
+void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) {
+    const ggml_tensor * src = relu_node->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src));
+    GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+    GGML_ASSERT(src->type == sqr_node->type);
+
+    const int k = ggml_nelements(src);
+    if (src->type == GGML_TYPE_F16) {
+        unary_cuda<op_relu_sqr>((const half *)src->data, (half *)sqr_node->data, k, stream);
+    } else {
+        unary_cuda<op_relu_sqr>((const float *)src->data, (float *)sqr_node->data, k, stream);
+    }
+}
