Skip to content

Commit 9b321e3

Browse files
committed
add variable ai_context_window_size
1 parent 2ab26a6 commit 9b321e3

9 files changed

Lines changed: 231 additions & 73 deletions

File tree

be/src/exprs/aggregate/aggregate_function_ai_agg.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@ class AggregateFunctionAIAggData {
3737
static constexpr const char* SEPARATOR = "\n";
3838
static constexpr uint8_t SEPARATOR_SIZE = sizeof(*SEPARATOR);
3939

40-
// 128K tokens is a relatively small context limit among mainstream AIs.
41-
// currently, token count is conservatively approximated by size; this is a safe lower bound.
42-
// a more efficient and accurate token calculation method may be introduced.
43-
static constexpr size_t MAX_CONTEXT_SIZE = 128 * 1024;
44-
4540
ColumnString::Chars data;
4641
bool inited = false;
4742

@@ -196,14 +191,22 @@ class AggregateFunctionAIAggData {
196191

197192
// handle overflow situations when adding content.
198193
bool handle_overflow(size_t additional_size) {
199-
if (additional_size + data.size() <= MAX_CONTEXT_SIZE) {
194+
const size_t max_context_size = get_ai_context_window_size();
195+
if (additional_size + data.size() <= max_context_size) {
200196
return false;
201197
}
202198

203199
process_current_context();
204200

205201
// check if there is still an overflow after replacement.
206-
return (additional_size + data.size() > MAX_CONTEXT_SIZE);
202+
return (additional_size + data.size() > max_context_size);
203+
}
204+
205+
static size_t get_ai_context_window_size() {
206+
DORIS_CHECK(_ctx);
207+
208+
int64_t context_window_size = _ctx->query_options().ai_context_window_size;
209+
return static_cast<size_t>(context_window_size > 0 ? context_window_size : 128 * 1024);
207210
}
208211

209212
void append_data(const void* source, size_t size) {

be/src/exprs/function/ai/ai_adapter.h

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ class AIAdapter {
180180
--end;
181181
}
182182

183-
if (begin < end && text[begin] == '[' && text[end - 1] == ']') {
183+
if (begin < end && text[begin] == '[' && text[end - 1] == ']' && end - begin >= 4 &&
184+
(text[begin + 1] == '"' || text[begin + 1] == '\'')) {
184185
rapidjson::Document doc;
185186
doc.Parse(text.data() + begin, end - begin);
186187
if (doc.HasParseError()) {
@@ -217,6 +218,50 @@ class AIAdapter {
217218
doc.AddMember(name, _config.dimensions, allocator);
218219
}
219220
}
221+
222+
// Validates common multimodal embedding request invariants shared by providers.
223+
Status validate_multimodal_embedding_inputs(
224+
std::string_view provider_name, const std::vector<MultimodalType>& media_types,
225+
const std::vector<std::string>& media_urls,
226+
std::initializer_list<MultimodalType> supported_types) const {
227+
if (media_urls.empty()) {
228+
return Status::InvalidArgument("{} multimodal embed inputs can not be empty",
229+
provider_name);
230+
}
231+
if (media_types.size() != media_urls.size()) {
232+
return Status::InvalidArgument(
233+
"{} multimodal embed input size mismatch, media_types={}, media_urls={}",
234+
provider_name, media_types.size(), media_urls.size());
235+
}
236+
for (MultimodalType media_type : media_types) {
237+
bool supported = false;
238+
for (MultimodalType supported_type : supported_types) {
239+
if (media_type == supported_type) {
240+
supported = true;
241+
break;
242+
}
243+
}
244+
if (!supported) [[unlikely]] {
245+
return Status::InvalidArgument(
246+
"{} only supports {} multimodal embed, got {}", provider_name,
247+
supported_multimodal_types_to_string(supported_types),
248+
multimodal_type_to_string(media_type));
249+
}
250+
}
251+
return Status::OK();
252+
}
253+
254+
static std::string supported_multimodal_types_to_string(
255+
std::initializer_list<MultimodalType> supported_types) {
256+
std::string result;
257+
for (MultimodalType type : supported_types) {
258+
if (!result.empty()) {
259+
result += "/";
260+
}
261+
result += multimodal_type_to_string(type);
262+
}
263+
return result;
264+
}
220265
};
221266

222267
// Most LLM-providers' Embedding formats are based on VoyageAI.
@@ -265,22 +310,9 @@ class VoyageAIAdapter : public AIAdapter {
265310
Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types,
266311
const std::vector<std::string>& media_urls,
267312
std::string& request_body) const override {
268-
if (media_urls.empty()) {
269-
return Status::InvalidArgument("VoyageAI multimodal embed inputs can not be empty");
270-
}
271-
if (media_types.size() != media_urls.size()) {
272-
return Status::InvalidArgument(
273-
"VoyageAI multimodal embed input size mismatch, media_types={}, media_urls={}",
274-
media_types.size(), media_urls.size());
275-
}
276-
for (MultimodalType media_type : media_types) {
277-
if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO)
278-
[[unlikely]] {
279-
return Status::InvalidArgument(
280-
"VoyageAI only supports image/video multimodal embed, got {}",
281-
multimodal_type_to_string(media_type));
282-
}
283-
}
313+
RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
314+
"VoyageAI", media_types, media_urls,
315+
{MultimodalType::IMAGE, MultimodalType::VIDEO}));
284316
if (_config.dimensions != -1) {
285317
LOG(WARNING) << "VoyageAI multimodal embedding currently ignores dimensions parameter, "
286318
<< "model=" << _config.model_name << ", dimensions=" << _config.dimensions;
@@ -937,21 +969,8 @@ class QwenAdapter : public OpenAIAdapter {
937969
Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types,
938970
const std::vector<std::string>& media_urls,
939971
std::string& request_body) const override {
940-
if (media_urls.empty()) {
941-
return Status::InvalidArgument("QWEN multimodal embed inputs can not be empty");
942-
}
943-
if (media_types.size() != media_urls.size()) {
944-
return Status::InvalidArgument(
945-
"QWEN multimodal embed input size mismatch, media_types={}, media_urls={}",
946-
media_types.size(), media_urls.size());
947-
}
948-
for (MultimodalType media_type : media_types) {
949-
if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO) {
950-
return Status::InvalidArgument(
951-
"QWEN only supports image/video multimodal embed, got {}",
952-
multimodal_type_to_string(media_type));
953-
}
954-
}
972+
RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
973+
"QWEN", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO}));
955974

956975
rapidjson::Document doc;
957976
doc.SetObject();
@@ -1058,22 +1077,8 @@ class JinaAdapter : public VoyageAIAdapter {
10581077
Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types,
10591078
const std::vector<std::string>& media_urls,
10601079
std::string& request_body) const override {
1061-
if (media_urls.empty()) {
1062-
return Status::InvalidArgument("JINA multimodal embed inputs can not be empty");
1063-
}
1064-
if (media_types.size() != media_urls.size()) {
1065-
return Status::InvalidArgument(
1066-
"JINA multimodal embed input size mismatch, media_types={}, media_urls={}",
1067-
media_types.size(), media_urls.size());
1068-
}
1069-
for (MultimodalType media_type : media_types) {
1070-
if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO)
1071-
[[unlikely]] {
1072-
return Status::InvalidArgument(
1073-
"JINA only supports image/video multimodal embed, got {}",
1074-
multimodal_type_to_string(media_type));
1075-
}
1076-
}
1080+
RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
1081+
"JINA", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO}));
10771082

10781083
rapidjson::Document doc;
10791084
doc.SetObject();
@@ -1318,14 +1323,9 @@ class GeminiAdapter : public AIAdapter {
13181323
Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types,
13191324
const std::vector<std::string>& media_urls,
13201325
std::string& request_body) const override {
1321-
if (media_urls.empty()) {
1322-
return Status::InvalidArgument("Gemini multimodal embed inputs can not be empty");
1323-
}
1324-
if (media_types.size() != media_urls.size()) {
1325-
return Status::InvalidArgument(
1326-
"Gemini multimodal embed input size mismatch, media_types={}, media_urls={}",
1327-
media_types.size(), media_urls.size());
1328-
}
1326+
RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
1327+
"Gemini", media_types, media_urls,
1328+
{MultimodalType::IMAGE, MultimodalType::AUDIO, MultimodalType::VIDEO}));
13291329

13301330
rapidjson::Document doc;
13311331
doc.SetObject();

be/src/exprs/function/ai/ai_functions.h

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ namespace doris {
5454
template <typename Derived>
5555
class AIFunction : public IFunction {
5656
public:
57-
static constexpr size_t max_batch_prompt_size = 128 * 1024;
58-
5957
std::string get_name() const override { return assert_cast<const Derived&>(*this).name; }
6058

6159
// If the user doesn't provide the first arg, `resource_name`
@@ -90,6 +88,17 @@ class AIFunction : public IFunction {
9088
}
9189

9290
protected:
91+
// Reads the shared AI context window size from query options. String AI batch functions and
92+
// ai_agg both use the same byte-based session variable so batching behavior stays consistent.
93+
static int64_t get_ai_context_window_size(FunctionContext* context) {
94+
DORIS_CHECK(context != nullptr);
95+
QueryContext* query_ctx = context->state()->get_query_ctx();
96+
DORIS_CHECK(query_ctx != nullptr);
97+
98+
int64_t context_window_size = query_ctx->query_options().ai_context_window_size;
99+
return context_window_size > 0 ? context_window_size : 128 * 1024;
100+
}
101+
93102
// Derived classes can override this method for non-text/default behavior.
94103
// The base implementation handles all string-input/string-output batchable functions.
95104
Status execute_with_adapter(FunctionContext* context, Block& block,
@@ -117,10 +126,18 @@ class AIFunction : public IFunction {
117126
}
118127

119128
static void normalize_endpoint(TAIResource& config) {
120-
// If users configure only the version root like `.../v1` or `.../v1beta`, append
121-
// `models/<model>:batchEmbedContents` for `embed`, and `models/<model>:generateContent`
122-
// for other AI scalar functions. If the endpoint is already a full method path, keep it.
129+
// 1. If users configure only the version root like `.../v1` or `.../v1beta`, append
130+
// `models/<model>:batchEmbedContents` for `embed`, and `models/<model>:generateContent`
131+
// for other AI scalar functions.
132+
// 2. `:embedContent` -> `:batchEmbedContents`
123133
if (iequal(config.provider_type, "GEMINI")) {
134+
if (iequal(Derived::name, "embed") && config.endpoint.ends_with(":embedContent")) {
135+
static constexpr std::string_view legacy_suffix = ":embedContent";
136+
config.endpoint.replace(config.endpoint.size() - legacy_suffix.size(),
137+
legacy_suffix.size(), ":batchEmbedContents");
138+
return;
139+
}
140+
124141
if (!config.endpoint.ends_with("v1") && !config.endpoint.ends_with("v1beta")) {
125142
return;
126143
}
@@ -270,6 +287,8 @@ class AIFunction : public IFunction {
270287
IColumn& col_result) const {
271288
std::vector<std::string> batch_prompts;
272289
size_t current_batch_size = 2; // []
290+
const size_t max_batch_prompt_size =
291+
static_cast<size_t>(get_ai_context_window_size(context));
273292

274293
for (size_t i = 0; i < input_rows_count; ++i) {
275294
std::string prompt;

be/src/exprs/function/ai/embed.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,16 @@ class FunctionEmbed : public AIFunction<FunctionEmbed> {
8787
std::vector<std::string> batch_prompts;
8888
size_t current_batch_size = 0;
8989
const int32_t max_batch_size = _get_embed_max_batch_size(context);
90+
const size_t max_context_window_size =
91+
static_cast<size_t>(get_ai_context_window_size(context));
9092

9193
for (size_t i = 0; i < input_rows_count; ++i) {
9294
std::string prompt;
9395
RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
9496

9597
const size_t prompt_size = prompt.size();
9698

97-
if (prompt_size > max_batch_prompt_size) {
99+
if (prompt_size > max_context_window_size) {
98100
// flush history batch
99101
RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
100102
adapter, context));
@@ -107,7 +109,7 @@ class FunctionEmbed : public AIFunction<FunctionEmbed> {
107109
}
108110

109111
if (!batch_prompts.empty() &&
110-
(current_batch_size + prompt_size > max_batch_prompt_size ||
112+
(current_batch_size + prompt_size > max_context_window_size ||
111113
batch_prompts.size() >= static_cast<size_t>(max_batch_size))) {
112114
RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
113115
adapter, context));

be/test/ai/aggregate_function_ai_agg_test.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,41 @@ TEST_F(AggregateFunctionAIAggTest, add_batch_single_place_multiple_calls_test) {
389389
_agg_function->destroy(place);
390390
}
391391

392+
TEST_F(AggregateFunctionAIAggTest, ai_context_window_size_session_variable_test) {
393+
TQueryOptions query_options = create_fake_query_options();
394+
query_options.__set_ai_context_window_size(8);
395+
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
396+
query_ctx->set_mock_ai_resource();
397+
_query_ctx = query_ctx;
398+
_agg_function->set_query_context(query_ctx.get());
399+
400+
auto resource_col = ColumnString::create();
401+
auto text_col = ColumnString::create();
402+
auto task_col = ColumnString::create();
403+
404+
resource_col->insert_data("mock_resource", 13);
405+
text_col->insert_data("abcd", 4);
406+
task_col->insert_data("summarize", 9);
407+
408+
resource_col->insert_data("mock_resource", 13);
409+
text_col->insert_data("efgh", 4);
410+
task_col->insert_data("summarize", 9);
411+
412+
std::unique_ptr<char[]> memory(new char[_agg_function->size_of_data()]);
413+
AggregateDataPtr place = memory.get();
414+
_agg_function->create(place);
415+
416+
const IColumn* columns[3] = {resource_col.get(), text_col.get(), task_col.get()};
417+
_agg_function->add(place, columns, 0, _arena);
418+
_agg_function->add(place, columns, 1, _arena);
419+
420+
const auto& data = *reinterpret_cast<const AggregateFunctionAIAggData*>(place);
421+
std::string actual(reinterpret_cast<const char*>(data.data.data()), data.data.size());
422+
EXPECT_EQ(actual, "this is a mock response\nefgh");
423+
424+
_agg_function->destroy(place);
425+
}
426+
392427
TEST_F(AggregateFunctionAIAggTest, mock_resource_send_request_test) {
393428
std::vector<std::string> resources = {"mock_resource"};
394429
std::vector<std::string> texts = {"test input"};

be/test/ai/ai_adapter_test.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,23 @@ TEST(AI_ADAPTER_TEST, openai_adapter_responses_parse_response) {
391391
ASSERT_EQ(results[0], "openai response result");
392392
}
393393

394+
TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_keeps_mask_literals) {
395+
OpenAIAdapter adapter;
396+
std::string resp = R"({"choices":[{"message":{"content":"[MSKED]"}}]})";
397+
std::vector<std::string> results;
398+
Status st = adapter.parse_response(resp, results);
399+
ASSERT_TRUE(st.ok()) << st.to_string();
400+
ASSERT_EQ(results.size(), 1);
401+
ASSERT_EQ(results[0], "[MSKED]");
402+
403+
resp = R"({"choices":[{"message":{"content":"[MASK]"}}]})";
404+
results.clear();
405+
st = adapter.parse_response(resp, results);
406+
ASSERT_TRUE(st.ok()) << st.to_string();
407+
ASSERT_EQ(results.size(), 1);
408+
ASSERT_EQ(results[0], "[MASK]");
409+
}
410+
394411
TEST(AI_ADAPTER_TEST, gemini_adapter_request) {
395412
GeminiAdapter adapter;
396413
TAIResource config;

0 commit comments

Comments (0)