Skip to content

Commit 36f2b66

Browse files
author
mkulakow
committed
Support functions in responses api
1 parent d85b89e commit 36f2b66

2 files changed

Lines changed: 187 additions & 19 deletions

File tree

src/llm/apis/openai_responses.cpp

Lines changed: 186 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,43 @@ absl::Status OpenAIResponsesHandler::parseInput(std::optional<std::string> allow
9595
}
9696

9797
auto itemObj = item.GetObject();
98+
99+
// Determine item type (if present)
100+
auto itemTypeIt = itemObj.FindMember("type");
101+
const std::string itemType = (itemTypeIt != itemObj.MemberEnd() && itemTypeIt->value.IsString())
102+
? itemTypeIt->value.GetString() : "";
103+
104+
// Skip reasoning items — they are internal chain-of-thought and not passed to the model
105+
if (itemType == "reasoning") {
106+
continue;
107+
}
108+
109+
// Handle function_call items (assistant tool use)
110+
// For chatHistory (non-Python path), represent as an assistant message with empty content.
111+
// The proper tool_calls structure is reconstructed in processedJson for the Python/Jinja path.
112+
if (itemType == "function_call") {
113+
request.chatHistory.push_back({});
114+
request.chatHistory.last()["role"] = "assistant";
115+
request.chatHistory.last()["content"] = "";
116+
continue;
117+
}
118+
119+
// Handle function_call_output items (tool results)
120+
if (itemType == "function_call_output") {
121+
auto callIdIt = itemObj.FindMember("call_id");
122+
auto outputIt = itemObj.FindMember("output");
123+
request.chatHistory.push_back({});
124+
request.chatHistory.last()["role"] = "tool";
125+
if (callIdIt != itemObj.MemberEnd() && callIdIt->value.IsString()) {
126+
request.chatHistory.last()["tool_call_id"] = callIdIt->value.GetString();
127+
}
128+
const std::string outputContent = (outputIt != itemObj.MemberEnd() && outputIt->value.IsString())
129+
? outputIt->value.GetString() : "";
130+
request.chatHistory.last()["content"] = outputContent;
131+
continue;
132+
}
133+
134+
// All remaining items must have a role field
98135
auto roleIt = itemObj.FindMember("role");
99136
if (roleIt == itemObj.MemberEnd() || !roleIt->value.IsString()) {
100137
return absl::InvalidArgumentError("input item role is missing or invalid");
@@ -105,7 +142,9 @@ absl::Status OpenAIResponsesHandler::parseInput(std::optional<std::string> allow
105142

106143
auto contentIt = itemObj.FindMember("content");
107144
if (contentIt == itemObj.MemberEnd()) {
108-
return absl::InvalidArgumentError("input item content is missing");
145+
// Allow messages without content (e.g., assistant message paired with tool calls)
146+
request.chatHistory.last()["content"] = "";
147+
continue;
109148
}
110149

111150
if (contentIt->value.IsString()) {
@@ -117,7 +156,9 @@ absl::Status OpenAIResponsesHandler::parseInput(std::optional<std::string> allow
117156
return absl::InvalidArgumentError("input item content must be a string or array");
118157
}
119158
if (contentIt->value.GetArray().Size() == 0) {
120-
return absl::InvalidArgumentError("Invalid message structure - content array is empty");
159+
// Empty content array is allowed (e.g., assistant message with only tool calls)
160+
request.chatHistory.last()["content"] = "";
161+
continue;
121162
}
122163

123164
std::string contentText = "";
@@ -132,10 +173,10 @@ absl::Status OpenAIResponsesHandler::parseInput(std::optional<std::string> allow
132173
}
133174

134175
const std::string type = typeIt->value.GetString();
135-
if (type == "input_text") {
176+
if (type == "input_text" || type == "output_text") {
136177
auto textIt = contentObj.FindMember("text");
137178
if (textIt == contentObj.MemberEnd() || !textIt->value.IsString()) {
138-
return absl::InvalidArgumentError("input_text requires a valid text field");
179+
return absl::InvalidArgumentError(absl::StrCat(type, " requires a valid text field"));
139180
}
140181
contentText = textIt->value.GetString();
141182
} else if (type == "input_image") {
@@ -163,7 +204,8 @@ absl::Status OpenAIResponsesHandler::parseInput(std::optional<std::string> allow
163204
}
164205
request.imageHistory.push_back({i, tensorResult.value()});
165206
} else {
166-
return absl::InvalidArgumentError("Unsupported content type. Supported types are input_text and input_image.");
207+
// Skip unrecognised content item types for forward compatibility
208+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Skipping unsupported content type: {}", type);
167209
}
168210
}
169211

@@ -228,27 +270,153 @@ absl::Status OpenAIResponsesHandler::parseResponsesPart(std::optional<uint32_t>
228270
}
229271

230272
#if (PYTHON_DISABLE == 0)
231-
// Build processedJson with "messages" array from chatHistory so that
232-
// the Python chat template path (which reads request_json["messages"])
233-
// can consume Responses API input without a separate code path.
273+
// Build processedJson with a "messages" array in chat/completions format so that
274+
// the Python Jinja template path can consume Responses API input without a separate code path.
275+
// Handles reasoning (skipped), function_call (merged into assistant tool_calls), and
276+
// function_call_output (converted to role:tool messages).
234277
{
235278
Document processedDoc;
236279
processedDoc.SetObject();
237280
auto& alloc = processedDoc.GetAllocator();
238281

239282
Value messagesArray(kArrayType);
240-
for (size_t i = 0; i < request.chatHistory.size(); ++i) {
241-
Value msgObj(kObjectType);
242-
auto role = request.chatHistory[i]["role"].as_string();
243-
if (role.has_value()) {
244-
msgObj.AddMember("role", Value(role.value().c_str(), alloc), alloc);
245-
}
246-
auto content = request.chatHistory[i]["content"].as_string();
247-
if (content.has_value()) {
248-
msgObj.AddMember("content", Value(content.value().c_str(), alloc), alloc);
283+
284+
auto inputArrIt = doc.FindMember("input");
285+
if (inputArrIt != doc.MemberEnd() && inputArrIt->value.IsArray()) {
286+
// Pending function_call items to be merged into the next assistant message
287+
std::vector<const rapidjson::Value*> pendingFunctionCalls;
288+
289+
// Helper: flush pending function_calls as an assistant message with the given text content
290+
auto flushPendingFunctionCalls = [&](const std::string& textContent) {
291+
if (pendingFunctionCalls.empty()) {
292+
return;
293+
}
294+
Value msgObj(kObjectType);
295+
msgObj.AddMember("role", Value("assistant", alloc), alloc);
296+
msgObj.AddMember("content", Value(textContent.c_str(), alloc), alloc);
297+
Value toolCallsArray(kArrayType);
298+
for (const auto* fc : pendingFunctionCalls) {
299+
auto fcObj = fc->GetObject();
300+
Value tcObj(kObjectType);
301+
auto idIt = fcObj.FindMember("id");
302+
const std::string tcId = (idIt != fcObj.MemberEnd() && idIt->value.IsString())
303+
? idIt->value.GetString() : "";
304+
tcObj.AddMember("id", Value(tcId.c_str(), alloc), alloc);
305+
tcObj.AddMember("type", Value("function", alloc), alloc);
306+
Value funcObj(kObjectType);
307+
auto nameIt = fcObj.FindMember("name");
308+
const std::string funcName = (nameIt != fcObj.MemberEnd() && nameIt->value.IsString())
309+
? nameIt->value.GetString() : "";
310+
funcObj.AddMember("name", Value(funcName.c_str(), alloc), alloc);
311+
auto argsIt = fcObj.FindMember("arguments");
312+
const std::string args = (argsIt != fcObj.MemberEnd() && argsIt->value.IsString())
313+
? argsIt->value.GetString() : "";
314+
funcObj.AddMember("arguments", Value(args.c_str(), alloc), alloc);
315+
tcObj.AddMember("function", funcObj, alloc);
316+
toolCallsArray.PushBack(tcObj, alloc);
317+
}
318+
msgObj.AddMember("tool_calls", toolCallsArray, alloc);
319+
messagesArray.PushBack(msgObj, alloc);
320+
pendingFunctionCalls.clear();
321+
};
322+
323+
// Helper: extract text content from a Responses API content field (string or array)
324+
auto extractTextContent = [&](const rapidjson::Value& contentVal) -> std::string {
325+
if (contentVal.IsString()) {
326+
return contentVal.GetString();
327+
}
328+
if (contentVal.IsArray()) {
329+
for (auto& ci : contentVal.GetArray()) {
330+
if (!ci.IsObject()) continue;
331+
auto ctTypeIt = ci.GetObject().FindMember("type");
332+
if (ctTypeIt == ci.GetObject().MemberEnd() || !ctTypeIt->value.IsString()) continue;
333+
const std::string ctType = ctTypeIt->value.GetString();
334+
if (ctType == "input_text" || ctType == "output_text") {
335+
auto textIt = ci.GetObject().FindMember("text");
336+
if (textIt != ci.GetObject().MemberEnd() && textIt->value.IsString()) {
337+
return textIt->value.GetString();
338+
}
339+
}
340+
}
341+
}
342+
return "";
343+
};
344+
345+
for (rapidjson::SizeType i = 0; i < inputArrIt->value.GetArray().Size(); ++i) {
346+
const auto& item = inputArrIt->value.GetArray()[i];
347+
if (!item.IsObject()) continue;
348+
auto itemObj = item.GetObject();
349+
350+
auto itemTypeIt = itemObj.FindMember("type");
351+
const std::string itemType = (itemTypeIt != itemObj.MemberEnd() && itemTypeIt->value.IsString())
352+
? itemTypeIt->value.GetString() : "";
353+
354+
// Skip reasoning items
355+
if (itemType == "reasoning") {
356+
continue;
357+
}
358+
359+
// Buffer function_call items — they will be merged with the next assistant message
360+
if (itemType == "function_call") {
361+
pendingFunctionCalls.push_back(&item);
362+
continue;
363+
}
364+
365+
// Convert function_call_output to role:tool message
366+
if (itemType == "function_call_output") {
367+
flushPendingFunctionCalls("");
368+
Value msgObj(kObjectType);
369+
msgObj.AddMember("role", Value("tool", alloc), alloc);
370+
auto callIdIt = itemObj.FindMember("call_id");
371+
if (callIdIt != itemObj.MemberEnd() && callIdIt->value.IsString()) {
372+
msgObj.AddMember("tool_call_id", Value(callIdIt->value.GetString(), alloc), alloc);
373+
}
374+
auto outputIt = itemObj.FindMember("output");
375+
const std::string outputContent = (outputIt != itemObj.MemberEnd() && outputIt->value.IsString())
376+
? outputIt->value.GetString() : "";
377+
msgObj.AddMember("content", Value(outputContent.c_str(), alloc), alloc);
378+
messagesArray.PushBack(msgObj, alloc);
379+
continue;
380+
}
381+
382+
// All other items must have a role
383+
auto roleIt = itemObj.FindMember("role");
384+
if (roleIt == itemObj.MemberEnd() || !roleIt->value.IsString()) {
385+
continue; // Skip unknown items without a role
386+
}
387+
const std::string role = roleIt->value.GetString();
388+
389+
std::string contentText = "";
390+
auto contentIt = itemObj.FindMember("content");
391+
if (contentIt != itemObj.MemberEnd()) {
392+
contentText = extractTextContent(contentIt->value);
393+
}
394+
395+
if (role == "assistant") {
396+
if (!pendingFunctionCalls.empty()) {
397+
// Merge buffered function_call items into this assistant message
398+
flushPendingFunctionCalls(contentText);
399+
} else {
400+
// Plain assistant message with no associated tool calls
401+
Value msgObj(kObjectType);
402+
msgObj.AddMember("role", Value("assistant", alloc), alloc);
403+
msgObj.AddMember("content", Value(contentText.c_str(), alloc), alloc);
404+
messagesArray.PushBack(msgObj, alloc);
405+
}
406+
} else {
407+
// Non-assistant message — flush any pending function calls first
408+
flushPendingFunctionCalls("");
409+
Value msgObj(kObjectType);
410+
msgObj.AddMember("role", Value(role.c_str(), alloc), alloc);
411+
msgObj.AddMember("content", Value(contentText.c_str(), alloc), alloc);
412+
messagesArray.PushBack(msgObj, alloc);
413+
}
249414
}
250-
messagesArray.PushBack(msgObj, alloc);
415+
416+
// Flush any trailing buffered function_calls
417+
flushPendingFunctionCalls("");
251418
}
419+
252420
processedDoc.AddMember("messages", messagesArray, alloc);
253421

254422
// Copy tools from original doc if present

src/llm/py_jinja_template_processor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ bool PyJinjaTemplateProcessor::applyChatTemplate(PyJinjaTemplateProcessor& templ
4040
output = "Error: Chat template not loaded correctly, so it cannot be applied";
4141
return false;
4242
}
43-
43+
SPDLOG_DEBUG("Before chat template: \n {}", requestBody);
4444
py::gil_scoped_acquire acquire;
4545
try {
4646
auto locals = py::dict("request_body"_a = requestBody, "chat_template"_a = templateProcessor.chatTemplate->getObject(),

0 commit comments

Comments
 (0)