Skip to content

Commit aa9119f

Browse files
committed
Require structured post-tool responses
1 parent de7bc28 commit aa9119f

6 files changed

Lines changed: 287 additions & 25 deletions

File tree

R/console.R

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,12 @@ console_chat <- function(session = NULL,
246246
console_working_dir = working_dir,
247247
console_startup_dir = startup_dir,
248248
console_profile = profile,
249-
console_session_store_root = file.path(startup_dir, ".aisdk", "sessions")
249+
console_agent_enabled = isTRUE(agent_mode),
250+
console_session_store_root = file.path(startup_dir, ".aisdk", "sessions"),
251+
model_call_options = merge_call_options(
252+
session$get_model_call_options(),
253+
list(require_post_tool_protocol = isTRUE(agent_mode))
254+
)
250255
))
251256
session_id <- console_session_id(session)
252257
branch_tree <- console_branch_tree(session)
@@ -444,6 +449,7 @@ console_send_user_message <- function(input,
444449
generation_result <- session$send_stream(
445450
input,
446451
turn_system_prompt = turn_system_prompt,
452+
require_post_tool_protocol = isTRUE(session$get_metadata("console_agent_enabled", default = FALSE)),
447453
callback = function(text, done) {
448454
display_text <- tool_markup_filter$process(text, done)
449455
if (isTRUE(done)) {
@@ -466,7 +472,11 @@ console_send_user_message <- function(input,
466472
)
467473
} else {
468474
md_renderer <- create_markdown_stream_renderer()
469-
generation_result <- session$send(input, turn_system_prompt = turn_system_prompt)
475+
generation_result <- session$send(
476+
input,
477+
turn_system_prompt = turn_system_prompt,
478+
require_post_tool_protocol = isTRUE(session$get_metadata("console_agent_enabled", default = FALSE))
479+
)
470480
if (!is.null(generation_result$text)) {
471481
if (!is.null(app_state)) {
472482
console_app_append_assistant_text(app_state, generation_result$text)
@@ -693,6 +703,7 @@ console_continue_run_action <- function(session,
693703
action = action,
694704
guidance = guidance,
695705
stream = TRUE,
706+
require_post_tool_protocol = TRUE,
696707
callback = function(text, done) {
697708
if (isTRUE(done)) {
698709
md_renderer$process_chunk(NULL, TRUE)
@@ -708,7 +719,8 @@ console_continue_run_action <- function(session,
708719
result <- session$continue_run(
709720
action = action,
710721
guidance = guidance,
711-
stream = FALSE
722+
stream = FALSE,
723+
require_post_tool_protocol = TRUE
712724
)
713725
if (!is.null(result$text)) {
714726
if (!is.null(app_state)) {

R/core_api.R

Lines changed: 105 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -219,25 +219,68 @@ text_tool_protocol_missing <- function(result, awaiting_protocol = FALSE) {
219219
}
220220

221221
#' @keywords internal
222-
text_tool_protocol_correction_message <- function(result) {
222+
#' Build the "continue with another tool call" instruction
#'
#' Produces the half of the post-tool protocol text that tells the model how
#' to request a further tool call.
#'
#' @param use_text_tool_fallback Logical. Only a value for which `isTRUE()`
#'   holds selects the literal `<tool_call>` JSON template; anything else
#'   selects the sentence pointing at the provider's native/API tool-call
#'   interface.
#' @return A length-one character vector containing the instruction text.
post_tool_protocol_tool_call_instruction <- function(use_text_tool_fallback = FALSE) {
  # Guard clause: the native-interface wording is the default path.
  if (!isTRUE(use_text_tool_fallback)) {
    native_instruction <- "Continue with another tool call by using the provider's native/API tool-call interface. Do not write prose while doing so."
    return(native_instruction)
  }

  # Text fallback: spell out the exact tag-wrapped JSON shape, one item per line.
  template_lines <- c(
    "Continue with another tool call:",
    "<tool_call>",
    "{\"name\":\"tool_name\",\"arguments\":{}}",
    "</tool_call>"
  )
  paste(template_lines, collapse = "\n")
}
235+
236+
#' Build the "finish the task" instruction
#'
#' Produces the half of the post-tool protocol text that shows the
#' `<final_answer>` wrapper the model should use when it is done.
#'
#' @return A length-one character vector with the four instruction lines
#'   joined by newlines.
#' @keywords internal
post_tool_protocol_final_answer_instruction <- function() {
  answer_lines <- c(
    "Or finish the task for the user:",
    "<final_answer>",
    "Your final answer to the user.",
    "</final_answer>"
  )
  paste(answer_lines, collapse = "\n")
}
246+
247+
#' Compose the system-prompt section describing the post-tool protocol
#'
#' Joins a fixed heading, a one-sentence rule, the tool-call instruction and
#' the final-answer instruction, separated by blank lines.
#'
#' @param use_text_tool_fallback Logical, forwarded unchanged to
#'   `post_tool_protocol_tool_call_instruction()` to pick the tool-call wording.
#' @return A length-one character vector containing the prompt section.
#' @keywords internal
post_tool_protocol_system_prompt <- function(use_text_tool_fallback = FALSE) {
  heading <- "Post-tool response protocol:"
  rule <- "After tool results are provided, your next response must be exactly one next action and no prose outside the required structure."
  tool_call_part <- post_tool_protocol_tool_call_instruction(
    use_text_tool_fallback = use_text_tool_fallback
  )
  final_answer_part <- post_tool_protocol_final_answer_instruction()

  # Blank line between each section.
  paste(
    c(heading, rule, tool_call_part, final_answer_part),
    collapse = "\n\n"
  )
}
257+
258+
#' Append a post-tool protocol reminder to a message list
#'
#' Builds a reminder consisting of a heading, a one-sentence rule, the
#' tool-call instruction and the final-answer instruction, then adds it to
#' the end of `messages` as a message with role `"user"`.
#'
#' @param messages The existing message list; it is combined with the new
#'   entry via `c()` and is not otherwise inspected.
#' @param use_text_tool_fallback Logical, forwarded unchanged to
#'   `post_tool_protocol_tool_call_instruction()` to pick the tool-call wording.
#' @return `messages` with one additional trailing element of the form
#'   `list(role = "user", content = <reminder text>)`.
#' @keywords internal
append_post_tool_protocol_message <- function(messages, use_text_tool_fallback = FALSE) {
  reminder_sections <- c(
    "Post-tool response protocol:",
    "Return exactly one of the following shapes and no prose outside the required structure:",
    post_tool_protocol_tool_call_instruction(use_text_tool_fallback = use_text_tool_fallback),
    post_tool_protocol_final_answer_instruction()
  )
  protocol_message <- list(
    role = "user",
    content = paste(reminder_sections, collapse = "\n\n")
  )

  # Wrap in an outer list so c() adds one message rather than splicing fields.
  c(messages, list(protocol_message))
}
269+
270+
#' @keywords internal
271+
text_tool_protocol_correction_message <- function(result, use_text_tool_fallback = TRUE) {
223272
preview_source <- result$final_answer_protocol_text %||%
224273
result$text_tool_call_protocol_text %||%
225274
result$text %||%
226275
""
227276
preview <- compact_text_preview(preview_source, width = 800)
228277
content <- paste(
229278
"Your previous response after tool results did not follow the required post-tool protocol.",
230-
"Do not explain the protocol. Re-emit the next action in exactly one of these forms and no prose outside the tags:",
279+
"Do not explain the protocol. Re-emit the next action in exactly one of these forms and no prose outside the required structure:",
231280
"",
232-
"Continue with another tool call:",
233-
"<tool_call>",
234-
"{\"name\":\"tool_name\",\"arguments\":{}}",
235-
"</tool_call>",
281+
post_tool_protocol_tool_call_instruction(use_text_tool_fallback = use_text_tool_fallback),
236282
"",
237-
"Or finish the task for the user:",
238-
"<final_answer>",
239-
"Your final answer to the user.",
240-
"</final_answer>",
283+
post_tool_protocol_final_answer_instruction(),
241284
if (nzchar(preview)) paste0("\nPrevious non-protocol response was:\n", preview) else NULL,
242285
sep = "\n"
243286
)
@@ -348,6 +391,10 @@ append_text_tool_result_messages <- function(messages, result, tool_results) {
348391
#' @param max_tool_result_errors Maximum number of consecutive tool result errors
349392
#' before triggering the circuit breaker. Default 2. Tool result errors are when
350393
#' tools return error messages (not exceptions). Set to Inf to disable this check.
394+
#' @param require_post_tool_protocol Logical. If TRUE, after any tool results
395+
#' are returned the model must either make another tool call or wrap its final
396+
#' answer in a `<final_answer>...</final_answer>` block. This is enabled
397+
#' automatically for text-based tool fallback.
351398
#' @param sandbox Logical. If TRUE, enables R-native programmatic sandbox mode.
352399
#' All tools are bound into an isolated R environment and replaced by a single
353400
#' `execute_r_code` meta-tool. The LLM writes R code to batch-invoke tools,
@@ -387,6 +434,7 @@ generate_text <- function(model = NULL,
387434
tools = NULL,
388435
max_steps = 1,
389436
max_tool_result_errors = 2,
437+
require_post_tool_protocol = FALSE,
390438
sandbox = FALSE,
391439
skills = NULL,
392440
session = NULL,
@@ -427,6 +475,7 @@ generate_text <- function(model = NULL,
427475

428476
tools <- filter_tools_for_model_capabilities(tools, model, session = session)
429477
use_text_tool_fallback <- !native_tool_calling_enabled(model)
478+
require_post_tool_protocol <- isTRUE(require_post_tool_protocol) || isTRUE(use_text_tool_fallback)
430479

431480
# Handle sandbox mode: bind tools into SandboxManager, replace with meta-tool
432481
if (isTRUE(sandbox) && !is.null(tools) && length(tools) > 0) {
@@ -453,6 +502,10 @@ generate_text <- function(model = NULL,
453502
system <- if (is.null(system)) tool_prompt else paste(system, "\n\n", tool_prompt, sep = "")
454503
}
455504
}
505+
if (isTRUE(require_post_tool_protocol) && !is.null(tools) && length(tools) > 0) {
506+
protocol_prompt <- post_tool_protocol_system_prompt(use_text_tool_fallback = use_text_tool_fallback)
507+
system <- if (is.null(system)) protocol_prompt else paste(system, "\n\n", protocol_prompt, sep = "")
508+
}
456509

457510
# Build initial messages
458511
messages <- build_messages(prompt, system)
@@ -476,7 +529,7 @@ generate_text <- function(model = NULL,
476529
step <- 0
477530
result <- NULL
478531
run_id <- paste0("run_", generate_stable_id("generate_text", Sys.time(), stats::runif(1)))
479-
awaiting_text_tool_protocol <- FALSE
532+
awaiting_post_tool_protocol <- FALSE
480533

481534
# Circuit breaker state
482535
breaker_state <- new.env(parent = emptyenv())
@@ -515,17 +568,20 @@ generate_text <- function(model = NULL,
515568
}
516569
}
517570

518-
if (text_tool_protocol_missing(result, awaiting_text_tool_protocol)) {
571+
if (text_tool_protocol_missing(result, awaiting_post_tool_protocol)) {
519572
if (step >= max_steps) {
520573
warning(sprintf("Maximum generation steps (%d) reached while waiting for a post-tool protocol response.", max_steps))
521574
result$finish_reason <- "tool_failure"
522575
break
523576
}
524-
messages <- c(messages, list(text_tool_protocol_correction_message(result)))
577+
messages <- c(messages, list(text_tool_protocol_correction_message(
578+
result,
579+
use_text_tool_fallback = use_text_tool_fallback
580+
)))
525581
next
526582
}
527583

528-
awaiting_text_tool_protocol <- FALSE
584+
awaiting_post_tool_protocol <- FALSE
529585

530586
# Check if there are tool calls to process
531587
if (!is.null(result$tool_calls) && length(result$tool_calls) > 0 && !is.null(tools)) {
@@ -651,7 +707,7 @@ generate_text <- function(model = NULL,
651707

652708
if (isTRUE(use_text_tool_fallback)) {
653709
messages <- append_text_tool_result_messages(messages, result, tool_results)
654-
awaiting_text_tool_protocol <- TRUE
710+
awaiting_post_tool_protocol <- TRUE
655711
} else if (history_format == "openai") {
656712
# For OpenAI, we need to include tool_calls in the assistant message
657713
assistant_message$tool_calls <- lapply(result$tool_calls, function(tc) {
@@ -682,6 +738,13 @@ generate_text <- function(model = NULL,
682738
tool_result_msg <- model$format_tool_result(tr$id, tr$name, tr$result)
683739
messages <- c(messages, list(tool_result_msg))
684740
}
741+
if (isTRUE(require_post_tool_protocol)) {
742+
messages <- append_post_tool_protocol_message(
743+
messages,
744+
use_text_tool_fallback = FALSE
745+
)
746+
awaiting_post_tool_protocol <- TRUE
747+
}
685748
}
686749

687750
if (isTRUE(tool_result_breaker_triggered)) {
@@ -763,6 +826,10 @@ generate_text <- function(model = NULL,
763826
#' @param max_tool_result_errors Maximum number of consecutive tool result errors
764827
#' before triggering the circuit breaker. Default 2. Tool result errors are when
765828
#' tools return error messages (not exceptions). Set to Inf to disable this check.
829+
#' @param require_post_tool_protocol Logical. If TRUE, after any tool results
830+
#' are returned the model must either make another tool call or wrap its final
831+
#' answer in a `<final_answer>...</final_answer>` block. This is enabled
832+
#' automatically for text-based tool fallback.
766833
#' @param sandbox Logical. If TRUE, enables R-native programmatic sandbox mode.
767834
#' See \code{generate_text} for details. Default FALSE.
768835
#' @param skills Optional path to skills directory, or a SkillRegistry object.
@@ -790,6 +857,7 @@ stream_text <- function(model = NULL,
790857
tools = NULL,
791858
max_steps = 1,
792859
max_tool_result_errors = 2,
860+
require_post_tool_protocol = FALSE,
793861
sandbox = FALSE,
794862
skills = NULL,
795863
session = NULL,
@@ -829,6 +897,7 @@ stream_text <- function(model = NULL,
829897

830898
tools <- filter_tools_for_model_capabilities(tools, model, session = session)
831899
use_text_tool_fallback <- !native_tool_calling_enabled(model)
900+
require_post_tool_protocol <- isTRUE(require_post_tool_protocol) || isTRUE(use_text_tool_fallback)
832901

833902
# Handle sandbox mode: bind tools into SandboxManager, replace with meta-tool
834903
if (isTRUE(sandbox) && !is.null(tools) && length(tools) > 0) {
@@ -855,6 +924,10 @@ stream_text <- function(model = NULL,
855924
system <- if (is.null(system)) tool_prompt else paste(system, "\n\n", tool_prompt, sep = "")
856925
}
857926
}
927+
if (isTRUE(require_post_tool_protocol) && !is.null(tools) && length(tools) > 0) {
928+
protocol_prompt <- post_tool_protocol_system_prompt(use_text_tool_fallback = use_text_tool_fallback)
929+
system <- if (is.null(system)) protocol_prompt else paste(system, "\n\n", protocol_prompt, sep = "")
930+
}
858931

859932
messages <- build_messages(prompt, system)
860933
validate_model_messages(model, messages)
@@ -876,7 +949,7 @@ stream_text <- function(model = NULL,
876949
step <- 0
877950
result <- NULL
878951
run_id <- paste0("run_", generate_stable_id("stream_text", Sys.time(), stats::runif(1)))
879-
awaiting_text_tool_protocol <- FALSE
952+
awaiting_post_tool_protocol <- FALSE
880953

881954
renderer <- create_stream_renderer()
882955

@@ -900,8 +973,8 @@ stream_text <- function(model = NULL,
900973
# Call the model via do_stream
901974
if (interactive()) renderer$start_thinking()
902975

903-
buffer_protocol_output <- isTRUE(use_text_tool_fallback) &&
904-
isTRUE(awaiting_text_tool_protocol)
976+
buffer_protocol_output <- isTRUE(require_post_tool_protocol) &&
977+
isTRUE(awaiting_post_tool_protocol)
905978

906979
result <- model$do_stream(params, function(chunk, done) {
907980
if (isTRUE(buffer_protocol_output)) {
@@ -920,13 +993,16 @@ stream_text <- function(model = NULL,
920993
result <- recover_text_tool_calls(result)
921994
result <- recover_text_final_answer(result)
922995

923-
if (text_tool_protocol_missing(result, awaiting_text_tool_protocol)) {
996+
if (text_tool_protocol_missing(result, awaiting_post_tool_protocol)) {
924997
if (step >= max_steps) {
925998
warning(sprintf("Maximum generation steps (%d) reached while waiting for a post-tool protocol response.", max_steps))
926999
result$finish_reason <- "tool_failure"
9271000
break
9281001
}
929-
messages <- c(messages, list(text_tool_protocol_correction_message(result)))
1002+
messages <- c(messages, list(text_tool_protocol_correction_message(
1003+
result,
1004+
use_text_tool_fallback = use_text_tool_fallback
1005+
)))
9301006
if (interactive()) {
9311007
renderer$reset_for_new_step()
9321008
}
@@ -950,7 +1026,7 @@ stream_text <- function(model = NULL,
9501026
}
9511027
}
9521028

953-
awaiting_text_tool_protocol <- FALSE
1029+
awaiting_post_tool_protocol <- FALSE
9541030

9551031
# Check if there are tool calls to process
9561032
if (!is.null(result$tool_calls) && length(result$tool_calls) > 0 && !is.null(tools)) {
@@ -1075,7 +1151,7 @@ stream_text <- function(model = NULL,
10751151

10761152
if (isTRUE(use_text_tool_fallback)) {
10771153
messages <- append_text_tool_result_messages(messages, result, tool_results)
1078-
awaiting_text_tool_protocol <- TRUE
1154+
awaiting_post_tool_protocol <- TRUE
10791155
} else if (history_format == "openai") {
10801156
# Provider-specific tool call formatting (copied from generate_text)
10811157
assistant_message$tool_calls <- lapply(result$tool_calls, function(tc) {
@@ -1105,6 +1181,13 @@ stream_text <- function(model = NULL,
11051181
tool_result_msg <- model$format_tool_result(tr$id, tr$name, tr$result)
11061182
messages <- c(messages, list(tool_result_msg))
11071183
}
1184+
if (isTRUE(require_post_tool_protocol)) {
1185+
messages <- append_post_tool_protocol_message(
1186+
messages,
1187+
use_text_tool_fallback = FALSE
1188+
)
1189+
awaiting_post_tool_protocol <- TRUE
1190+
}
11081191
}
11091192

11101193
# Reset renderer state for next step

man/generate_text.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/stream_text.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)