Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion be/src/core/column/predicate_column.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ class PredicateColumnType final : public COWHelper<IColumn, PredicateColumnType<
for (size_t i = 0; i < n; i++) {
memcpy(dst, str.data(), str.size());
insert_string_value(dst, str.size());
dst += i * str.size();
dst += str.size();
}
} else if constexpr (Type == TYPE_LARGEINT) {
const auto& v = x.get<TYPE_LARGEINT>();
Expand Down
65 changes: 43 additions & 22 deletions be/src/exprs/aggregate/aggregate_function_ai_agg.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "runtime/query_context.h"
#include "runtime/runtime_state.h"
#include "service/http/http_client.h"
#include "util/string_util.h"

namespace doris {

Expand All @@ -37,21 +38,12 @@ class AggregateFunctionAIAggData {
static constexpr const char* SEPARATOR = "\n";
static constexpr uint8_t SEPARATOR_SIZE = sizeof(*SEPARATOR);

// 128K tokens is a relatively small context limit among mainstream AIs.
// currently, token count is conservatively approximated by size; this is a safe lower bound.
// a more efficient and accurate token calculation method may be introduced.
static constexpr size_t MAX_CONTEXT_SIZE = 128 * 1024;

ColumnString::Chars data;
bool inited = false;

// Appends one input value to the accumulated context.
//
// `ref` is the value to append. The size handed to handle_overflow() is the
// value's byte length plus SEPARATOR_SIZE when data has already been
// accumulated (`inited` is true).
//
// The context window is a soft batching trigger, not a hard limit:
// handle_overflow() is called exactly once and its outcome is not treated as
// an error, so an oversized value is still appended rather than rejected.
void add(StringRef ref) {
    const auto delta_size = ref.size + (inited ? SEPARATOR_SIZE : 0);

    // Gives handle_overflow() the chance to process the buffered context
    // before this value is appended.
    handle_overflow(delta_size);

    append_data(ref.data, ref.size);
}

Expand All @@ -64,11 +56,7 @@ class AggregateFunctionAIAggData {
_task = rhs._task;

size_t delta_size = (inited ? SEPARATOR_SIZE : 0) + rhs.data.size();
if (handle_overflow(delta_size)) {
throw Exception(ErrorCode::OUT_OF_BOUND,
"Failed to merge data: combined context size exceeded "
"maximum limit even after processing");
}
handle_overflow(delta_size);

if (!inited) {
inited = true;
Expand Down Expand Up @@ -151,6 +139,7 @@ class AggregateFunctionAIAggData {
throw Exception(ErrorCode::NOT_FOUND, "AI resource not found: " + resource_name);
}
_ai_config = it->second;
normalize_endpoint(_ai_config);

_ai_adapter = AIAdapterFactory::create_adapter(_ai_config.provider_type);
_ai_adapter->init(_ai_config);
Expand All @@ -161,6 +150,10 @@ class AggregateFunctionAIAggData {

// Read-only accessor for the stored task string (`_task`).
const std::string& get_task() const {
    return _task;
}

#ifdef BE_TEST
// Test-only entry point (compiled only when BE_TEST is defined): forwards to
// normalize_endpoint() so it can be called directly on an AIResource.
static void normalize_endpoint_for_test(AIResource& config) {
    normalize_endpoint(config);
}
#endif

private:
Status send_request_to_ai(const std::string& request_body, std::string& response) const {
// Mock path for testing
Expand Down Expand Up @@ -194,16 +187,44 @@ class AggregateFunctionAIAggData {
return client->execute_post_request(request_body, &response);
}

Comment thread
linrrzqqq marked this conversation as resolved.
// handle overflow situations when adding content.
bool handle_overflow(size_t additional_size) {
if (additional_size + data.size() <= MAX_CONTEXT_SIZE) {
return false;
// Soft batching trigger: the context window is never a hard reject.
//
// `additional_size` is the number of bytes about to be appended. When data
// has already been accumulated (`inited`) and appending would push the buffer
// past the configured context window, the current context is processed
// first. Otherwise this is a no-op.
void handle_overflow(size_t additional_size) {
    const size_t window_limit = get_ai_context_window_size();

    const bool has_buffered_data = inited;
    const bool fits_in_window = additional_size + data.size() <= window_limit;

    if (has_buffered_data && !fits_in_window) {
        process_current_context();
    }
}

Comment thread
linrrzqqq marked this conversation as resolved.
// check if there is still an overflow after replacement.
return (additional_size + data.size() > MAX_CONTEXT_SIZE);
// Returns the `ai_context_window_size` query option converted to size_t.
// Asserts via DORIS_CHECK that `_ctx` is set before it is dereferenced.
// NOTE(review): the option's declared type is not visible here; if it is
// signed, a negative value would convert to a very large size_t — confirm.
static size_t get_ai_context_window_size() {
    DORIS_CHECK(_ctx);

    const auto& options = _ctx->query_options();
    return static_cast<size_t>(options.ai_context_window_size);
}

// Rewrites `config.endpoint` in place into the form the request code expects.
//
// - Provider "GEMINI" (compared with iequal): when the endpoint ends with
//   "v1" or "v1beta", "/<model>:generateContent" is appended, where <model>
//   is `config.model_name` prefixed with "models/" unless it already starts
//   with that prefix. Any other Gemini endpoint is left untouched.
// - Every other provider: a trailing "v1/completions" is replaced by
//   "v1/chat/completions"; other endpoints are left untouched.
static void normalize_endpoint(AIResource& config) {
    if (!iequal(config.provider_type, "GEMINI")) {
        constexpr std::string_view legacy_suffix = "v1/completions";
        if (config.endpoint.ends_with(legacy_suffix)) {
            const size_t suffix_pos = config.endpoint.size() - legacy_suffix.size();
            config.endpoint.replace(suffix_pos, legacy_suffix.size(), "v1/chat/completions");
        }
        return;
    }

    // Gemini: only an endpoint ending in the API version is expanded.
    const bool ends_with_api_version =
            config.endpoint.ends_with("v1") || config.endpoint.ends_with("v1beta");
    if (!ends_with_api_version) {
        return;
    }

    config.endpoint += "/";
    if (!config.model_name.starts_with("models/")) {
        config.endpoint += "models/";
    }
    config.endpoint += config.model_name;
    config.endpoint += ":generateContent";
}

void append_data(const void* source, size_t size) {
Expand Down Expand Up @@ -305,4 +326,4 @@ class AggregateFunctionAIAgg final
}
};

} // namespace doris
} // namespace doris
Loading
Loading