Skip to content

Commit 2a8afb9

Browse files
authored
[Opt](ai-func) Improving AI function performance (#62494)
### Release note Improves the performance of AI functions through batch sending. EMBED controls the number of (text/file) items sent in a single batch via the variable `embed_max_batch_size`, while the remaining AI functions internally maintain a conservative context window. The current sending format is similar to: ```json "input": [ {"role": "system", "content": "system_prompt here"}, {"role": "user", "content": [ {"idx": 1, "text": "xxx"}, {"idx": 2, "text": "xxx"}, ] } ] ``` performance: ```sql -- AI_CLASSIFY SELECT COUNT(*) AS total_rows, SUM(IF(res = 'science', 1, 0)) AS excepte_eq_res FROM ( SELECT AI_CLASSIFY('deepseek-chat', str, ['science', 'sport']) AS res FROM test_str ) t; -- before +------------+----------------+ | total_rows | excepte_eq_res | +------------+----------------+ | 100 | 100 | +------------+----------------+ 1 row in set (2 min 11.579 sec) -- now +------------+----------------+ | total_rows | excepte_eq_res | +------------+----------------+ | 100 | 100 | +------------+----------------+ 1 row in set (10.487 sec) -- AI_FILTER SELECT COUNT(*) AS total_rows, SUM(IF(res = 1, 1, 0)) AS zero_res_rows FROM ( SELECT AI_FILTER('deepseek-chat', str) AS res FROM test_str ) t; -- before +------------+---------------+ | total_rows | zero_res_rows | +------------+---------------+ | 100 | 0 | +------------+---------------+ 1 row in set (2 min 2.979 sec) -- now +------------+---------------+ | total_rows | zero_res_rows | +------------+---------------+ | 100 | 0 | +------------+---------------+ 1 row in set (5.007 sec) -- EMBED select count(embed('qwen-embed', str)) FROM test_str; -- before +---------------------------------+ | count(embed('qwen-embed', str)) | +---------------------------------+ | 100 | +---------------------------------+ 1 row in set (4 min 4.888 sec) -- now set embed_max_batch_size = 10; +---------------------------------+ | count(embed('qwen-embed', str)) | +---------------------------------+ | 100 | 
+---------------------------------+ 1 row in set (23.424 sec) -- Multimodal_Embed SELECT COUNT(EMBED('qwen_mul_embed', to_json(file))) FROM test_jpg2; -- before: could not get results for a long time (over 20 mins). -- now set embed_max_batch_size = 20; +----------------------------------------------------+ | .... | | 1152 | +----------------------------------------------------+ 1142 rows in set (1 min 13.577 sec) ```
1 parent 1c7e71e commit 2a8afb9

22 files changed

Lines changed: 2016 additions & 512 deletions

be/src/exprs/aggregate/aggregate_function_ai_agg.h

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "runtime/query_context.h"
3030
#include "runtime/runtime_state.h"
3131
#include "service/http/http_client.h"
32+
#include "util/string_util.h"
3233

3334
namespace doris {
3435

@@ -37,21 +38,12 @@ class AggregateFunctionAIAggData {
3738
static constexpr const char* SEPARATOR = "\n";
3839
static constexpr uint8_t SEPARATOR_SIZE = sizeof(*SEPARATOR);
3940

40-
// 128K tokens is a relatively small context limit among mainstream AIs.
41-
// currently, token count is conservatively approximated by size; this is a safe lower bound.
42-
// a more efficient and accurate token calculation method may be introduced.
43-
static constexpr size_t MAX_CONTEXT_SIZE = 128 * 1024;
44-
4541
ColumnString::Chars data;
4642
bool inited = false;
4743

4844
void add(StringRef ref) {
4945
auto delta_size = ref.size + (inited ? SEPARATOR_SIZE : 0);
50-
if (handle_overflow(delta_size)) {
51-
throw Exception(ErrorCode::OUT_OF_BOUND,
52-
"Failed to add data: combined context size exceeded "
53-
"maximum limit even after processing");
54-
}
46+
handle_overflow(delta_size);
5547
append_data(ref.data, ref.size);
5648
}
5749

@@ -64,11 +56,7 @@ class AggregateFunctionAIAggData {
6456
_task = rhs._task;
6557

6658
size_t delta_size = (inited ? SEPARATOR_SIZE : 0) + rhs.data.size();
67-
if (handle_overflow(delta_size)) {
68-
throw Exception(ErrorCode::OUT_OF_BOUND,
69-
"Failed to merge data: combined context size exceeded "
70-
"maximum limit even after processing");
71-
}
59+
handle_overflow(delta_size);
7260

7361
if (!inited) {
7462
inited = true;
@@ -151,6 +139,7 @@ class AggregateFunctionAIAggData {
151139
throw Exception(ErrorCode::NOT_FOUND, "AI resource not found: " + resource_name);
152140
}
153141
_ai_config = it->second;
142+
normalize_endpoint(_ai_config);
154143

155144
_ai_adapter = AIAdapterFactory::create_adapter(_ai_config.provider_type);
156145
_ai_adapter->init(_ai_config);
@@ -161,6 +150,10 @@ class AggregateFunctionAIAggData {
161150

162151
const std::string& get_task() const { return _task; }
163152

153+
#ifdef BE_TEST
154+
static void normalize_endpoint_for_test(AIResource& config) { normalize_endpoint(config); }
155+
#endif
156+
164157
private:
165158
Status send_request_to_ai(const std::string& request_body, std::string& response) const {
166159
// Mock path for testing
@@ -194,16 +187,44 @@ class AggregateFunctionAIAggData {
194187
return client->execute_post_request(request_body, &response);
195188
}
196189

197-
// handle overflow situations when adding content.
198-
bool handle_overflow(size_t additional_size) {
199-
if (additional_size + data.size() <= MAX_CONTEXT_SIZE) {
200-
return false;
190+
// Treat the context window as a soft batching trigger instead of a hard reject.
191+
void handle_overflow(size_t additional_size) {
192+
const size_t max_context_size = get_ai_context_window_size();
193+
if (additional_size + data.size() <= max_context_size || !inited) {
194+
return;
201195
}
202196

203197
process_current_context();
198+
}
204199

205-
// check if there is still an overflow after replacement.
206-
return (additional_size + data.size() > MAX_CONTEXT_SIZE);
200+
static size_t get_ai_context_window_size() {
201+
DORIS_CHECK(_ctx);
202+
203+
return static_cast<size_t>(_ctx->query_options().ai_context_window_size);
204+
}
205+
206+
static void normalize_endpoint(AIResource& config) {
207+
if (iequal(config.provider_type, "GEMINI")) {
208+
if (!config.endpoint.ends_with("v1") && !config.endpoint.ends_with("v1beta")) {
209+
return;
210+
}
211+
212+
std::string model_name = config.model_name;
213+
if (!model_name.starts_with("models/")) {
214+
model_name = "models/" + model_name;
215+
}
216+
217+
config.endpoint += "/";
218+
config.endpoint += model_name;
219+
config.endpoint += ":generateContent";
220+
return;
221+
}
222+
223+
if (config.endpoint.ends_with("v1/completions")) {
224+
static constexpr std::string_view legacy_suffix = "v1/completions";
225+
config.endpoint.replace(config.endpoint.size() - legacy_suffix.size(),
226+
legacy_suffix.size(), "v1/chat/completions");
227+
}
207228
}
208229

209230
void append_data(const void* source, size_t size) {
@@ -305,4 +326,4 @@ class AggregateFunctionAIAgg final
305326
}
306327
};
307328

308-
} // namespace doris
329+
} // namespace doris

0 commit comments

Comments
 (0)