Skip to content

Commit d837850

Browse files
committed
Validate tool arguments before execution
1 parent aa9119f commit d837850

12 files changed

Lines changed: 475 additions & 31 deletions

R/console_app.R

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,18 @@ console_app_record_tool_result <- function(state, name, result, success = TRUE,
541541

542542
item <- turn$tool_calls[[match_idx]]
543543
item$end_time <- Sys.time()
544+
display_status <- if (is.list(raw_result) && identical(raw_result$error_type %||% NULL, "invalid_tool_arguments")) {
545+
"invalid_arguments"
546+
} else {
547+
NULL
548+
}
544549
item$status <- if (failed) "failed" else "done"
545-
item$result_summary <- compact_tool_result_label(name, result, success = !failed)
550+
item$result_summary <- compact_tool_result_label(
551+
name,
552+
result,
553+
success = !failed,
554+
display_status = display_status
555+
)
546556
item$raw_result <- raw_result
547557

548558
diagnostics <- extract_console_tool_diagnostics(raw_result, rendered_result = result)

R/core_api.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,9 @@ generate_text <- function(model = NULL,
667667
# --- Circuit Breaker: Detect consecutive tool result errors ---
668668
# Check if any tools returned errors (not exceptions, but error results)
669669
error_count <- sum(vapply(tool_results, function(tr) {
670+
if (isTRUE(tr$is_validation_error)) {
671+
return(FALSE)
672+
}
670673
tool_result_indicates_error(tr$result, tr$raw_result %||% tr$result)
671674
}, logical(1)))
672675
tool_result_breaker_triggered <- FALSE
@@ -1110,6 +1113,9 @@ stream_text <- function(model = NULL,
11101113

11111114
# --- Circuit Breaker: Detect tool result errors ---
11121115
tool_result_error_count <- sum(vapply(tool_results, function(tr) {
1116+
if (isTRUE(tr$is_validation_error)) {
1117+
return(FALSE)
1118+
}
11131119
tool_result_indicates_error(tr$result, tr$raw_result %||% tr$result)
11141120
}, logical(1)))
11151121
tool_result_breaker_triggered <- FALSE

R/r_introspect_tools.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,10 @@ create_r_introspect_tools <- function() {
867867
"The subprocess cannot modify the user's session."
868868
),
869869
parameters = z_object(
870-
code = z_string("R code to evaluate (one or more expressions)."),
870+
code = z_string(
871+
"R code to evaluate (one or more expressions).",
872+
min_length = 1
873+
),
871874
timeout_secs = z_integer(
872875
"Maximum seconds to wait before killing the subprocess (default 30, hard cap 120).",
873876
nullable = TRUE
@@ -888,7 +891,8 @@ create_r_introspect_tools <- function() {
888891
"you to test a real API call."
889892
),
890893
nullable = TRUE
891-
)
894+
),
895+
.required = "code"
892896
),
893897
execute = function(args) {
894898
code <- args$code
@@ -941,7 +945,8 @@ create_r_introspect_tools <- function() {
941945
)
942946

943947
format_r_eval_result(captured, code = code, timeout_secs = timeout_secs)
944-
}
948+
},
949+
meta = list(validate_arguments = TRUE)
945950
),
946951
tool(
947952
name = "r_session_state",

R/schema.R

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,31 @@ NULL
1313
#' @description Create a JSON Schema for string type.
1414
#' @param description Optional description of the field.
1515
#' @param nullable If TRUE, allows null values.
16-
#' @param default Optional default value.
17-
#' @return A list representing JSON Schema for string.
18-
#' @export
19-
#' @examples
20-
#' z_string(description = "The city name")
21-
z_string <- function(description = NULL, nullable = FALSE, default = NULL) {
22-
schema <- list(type = "string")
23-
if (!is.null(description)) {
24-
schema$description <- description
25-
}
26-
if (!is.null(default)) {
27-
schema$default <- default
28-
}
29-
if (nullable) {
30-
schema$type <- c("string", "null")
31-
}
16+
#' @param default Optional default value.
#' @param min_length Optional minimum string length. Must be a single
#'   non-negative whole number; emitted as the JSON Schema `minLength` keyword.
#' @param max_length Optional maximum string length. Must be a single
#'   non-negative whole number that is not smaller than `min_length`; emitted
#'   as the JSON Schema `maxLength` keyword.
#' @return A list representing JSON Schema for string.
#' @export
#' @examples
#' z_string(description = "The city name")
#' z_string(description = "Non-empty code", min_length = 1)
z_string <- function(description = NULL, nullable = FALSE, default = NULL,
                     min_length = NULL, max_length = NULL) {
  # Coerce a length bound to integer, rejecting values that cannot be a valid
  # JSON Schema length keyword. A bare as.integer() would silently truncate
  # fractions and turn non-numeric or out-of-range input into NA.
  as_length_bound <- function(bound, arg) {
    if (is.null(bound)) {
      return(NULL)
    }
    is_valid <- is.numeric(bound) &&
      length(bound) == 1L &&
      is.finite(bound) &&
      bound >= 0 &&
      bound <= .Machine$integer.max &&
      bound == trunc(bound)
    if (!is_valid) {
      stop(
        sprintf("`%s` must be a single non-negative whole number.", arg),
        call. = FALSE
      )
    }
    as.integer(bound)
  }

  min_length <- as_length_bound(min_length, "min_length")
  max_length <- as_length_bound(max_length, "max_length")
  if (!is.null(min_length) && !is.null(max_length) && min_length > max_length) {
    # Such a schema could never be satisfied by any string.
    stop("`min_length` must not be greater than `max_length`.", call. = FALSE)
  }

  schema <- list(type = "string")
  if (!is.null(description)) {
    schema$description <- description
  }
  if (!is.null(default)) {
    schema$default <- default
  }
  if (!is.null(min_length)) {
    schema$minLength <- min_length
  }
  if (!is.null(max_length)) {
    schema$maxLength <- max_length
  }
  if (nullable) {
    schema$type <- c("string", "null")
  }
  class(schema) <- c("z_schema", "z_string", "list")
  schema
}

R/tool.R

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,201 @@ tool_result_indicates_error <- function(result, raw_result = result) {
776776
FALSE
777777
}
778778

779+
#' Check whether a tool opted in to argument validation
#'
#' @param tool_obj A tool object; its `meta$validate_arguments` entry is read.
#' @return `TRUE` only when `meta$validate_arguments` is exactly `TRUE`,
#'   otherwise `FALSE`.
#' @keywords internal
tool_argument_validation_enabled <- function(tool_obj) {
  opt_in_flag <- tool_obj$meta$validate_arguments
  isTRUE(opt_in_flag)
}
783+
784+
#' @keywords internal
785+
schema_allows_null <- function(schema) {
786+
"null" %in% (schema$type %||% character(0))
787+
}
788+
789+
#' Build a human-readable type label for a schema
#'
#' @param schema A schema object.
#' @return A single string: the entries of `schema$type` joined with
#'   `" or "`, or the first class of `schema` when `type` is absent.
#' @keywords internal
schema_type_label <- function(schema) {
  type_names <- schema$type %||% class(schema)[[1]]
  paste(type_names, collapse = " or ")
}
793+
794+
#' Validate a single value against a schema node
#'
#' Dispatches on the S3 class of `schema` (`z_any`, `z_string`, `z_enum`,
#' `z_boolean`, `z_integer`, `z_number`, `z_array`, `z_object`,
#' `z_any_object`) and collects human-readable error messages. Array items
#' and object properties are validated recursively.
#'
#' @param value The value to check; may be `NULL`.
#' @param schema The schema node describing the expected value.
#' @param path Location of `value` used in messages (for example `"opts.n"`
#'   or `"items[2]"`); an empty string means the root value.
#' @return A character vector of error messages; `character(0)` when the
#'   value is valid or the schema class is not one of the handled kinds.
#' @keywords internal
validate_schema_value <- function(value, schema, path) {
  errors <- character(0)
  label <- if (nzchar(path)) paste0("`", path, "`") else "value"

  # NULL is only acceptable when the schema lists "null" among its types.
  if (is.null(value)) {
    if (schema_allows_null(schema)) {
      return(errors)
    }
    return(sprintf("Argument %s must not be null.", label))
  }

  if (inherits(schema, "z_any")) {
    return(errors)
  }

  if (inherits(schema, "z_string") || inherits(schema, "z_enum")) {
    if (!is.character(value) || length(value) != 1 || is.na(value)) {
      return(sprintf("Argument %s must be a single string.", label))
    }
    min_length <- schema$minLength
    max_length <- schema$maxLength
    # Only measure the string when a length bound is actually declared.
    if (!is.null(min_length) || !is.null(max_length)) {
      n_chars <- nchar(value, type = "chars", allowNA = FALSE)
      if (!is.null(min_length) && n_chars < min_length) {
        errors <- c(errors, sprintf(
          "Argument %s must contain at least %d character%s.",
          label,
          min_length,
          if (identical(as.integer(min_length), 1L)) "" else "s"
        ))
      }
      if (!is.null(max_length) && n_chars > max_length) {
        errors <- c(errors, sprintf(
          "Argument %s must contain at most %d character%s.",
          label,
          max_length,
          if (identical(as.integer(max_length), 1L)) "" else "s"
        ))
      }
    }
    # unlist(NULL) is NULL (length 0), so a missing `enum` skips this check.
    enum_values <- unlist(schema$enum, use.names = FALSE)
    if (length(enum_values) > 0 && !value %in% enum_values) {
      errors <- c(errors, sprintf(
        "Argument %s must be one of: %s.",
        label,
        paste(sprintf("`%s`", enum_values), collapse = ", ")
      ))
    }
    return(errors)
  }

  if (inherits(schema, "z_boolean")) {
    if (!is.logical(value) || length(value) != 1 || is.na(value)) {
      return(sprintf("Argument %s must be a single boolean.", label))
    }
    return(errors)
  }

  if (inherits(schema, "z_integer")) {
    # Compare against trunc() rather than as.integer(): as.integer() returns
    # NA (with a warning) for Inf and for magnitudes beyond
    # .Machine$integer.max, which would make this condition NA and abort the
    # validator instead of producing a result.
    if (!is.numeric(value) || length(value) != 1 || is.na(value) ||
      !is.finite(value) || value != trunc(value)) {
      return(sprintf("Argument %s must be a single integer.", label))
    }
  } else if (inherits(schema, "z_number")) {
    if (!is.numeric(value) || length(value) != 1 || is.na(value)) {
      return(sprintf("Argument %s must be a single number.", label))
    }
  }

  # Range bounds are shared by integer and number schemas.
  if (inherits(schema, "z_integer") || inherits(schema, "z_number")) {
    if (!is.null(schema$minimum) && value < schema$minimum) {
      errors <- c(errors, sprintf("Argument %s must be >= %s.", label, schema$minimum))
    }
    if (!is.null(schema$maximum) && value > schema$maximum) {
      errors <- c(errors, sprintf("Argument %s must be <= %s.", label, schema$maximum))
    }
    return(errors)
  }

  if (inherits(schema, "z_array")) {
    # Both lists and atomic vectors are accepted as arrays.
    if (!is.list(value) && !is.atomic(value)) {
      return(sprintf("Argument %s must be an array.", label))
    }
    n_items <- length(value)
    if (!is.null(schema$minItems) && n_items < schema$minItems) {
      errors <- c(errors, sprintf("Argument %s must contain at least %d item(s).", label, schema$minItems))
    }
    if (!is.null(schema$maxItems) && n_items > schema$maxItems) {
      errors <- c(errors, sprintf("Argument %s must contain at most %d item(s).", label, schema$maxItems))
    }
    if (!is.null(schema$items) && inherits(schema$items, "z_schema")) {
      for (i in seq_along(value)) {
        errors <- c(errors, validate_schema_value(value[[i]], schema$items, paste0(path, "[", i, "]")))
      }
    }
    return(errors)
  }

  if (inherits(schema, "z_object") || inherits(schema, "z_any_object")) {
    if (!is.list(value)) {
      return(sprintf("Argument %s must be an object.", label))
    }
    return(validate_object_arguments(value, schema, path = path))
  }

  # Unrecognised schema classes are not validated.
  errors
}
898+
899+
#' Validate the fields of an object value against an object schema
#'
#' Reports every name in `schema$required` that is absent from `args`, then
#' validates each supplied field that has an entry in `schema$properties`.
#' Supplied fields without a matching property are not checked.
#'
#' @param args A (named) list of supplied fields.
#' @param schema An object schema with optional `properties` and `required`.
#' @param path Dotted prefix of the enclosing object used in messages; an
#'   empty string means the top level.
#' @return A character vector of error messages, missing-field errors first;
#'   `character(0)` when nothing is wrong.
#' @keywords internal
validate_object_arguments <- function(args, schema, path = "") {
  properties <- schema$properties %||% list()
  required_names <- unlist(schema$required %||% character(0), use.names = FALSE)
  supplied <- names(args) %||% character(0)

  # Prefix a field name with the enclosing path, if there is one.
  qualify <- function(field) {
    if (nzchar(path)) paste0(path, ".", field) else field
  }

  absent <- required_names[!required_names %in% supplied]
  missing_errors <- sprintf(
    "Missing required argument `%s`.",
    vapply(absent, qualify, character(1), USE.NAMES = FALSE)
  )

  checkable <- intersect(names(properties), supplied)
  field_errors <- unlist(
    lapply(checkable, function(field) {
      # When a name is supplied more than once, the first occurrence wins.
      field_value <- args[[match(field, supplied)]]
      validate_schema_value(field_value, properties[[field]], qualify(field))
    }),
    use.names = FALSE
  )

  c(missing_errors, field_errors)
}
921+
922+
#' Parse and validate the arguments of a tool call
#'
#' Parses `args` with `parse_tool_arguments()` and, when the tool's
#' `parameters` is a `z_schema`, validates the parsed value against it.
#' Tools whose parameters are not a `z_schema` are always reported as valid.
#'
#' @param tool_obj A tool object providing `name` and `parameters`.
#' @param args The raw arguments of the tool call.
#' @return A list with `valid` (logical), `arguments` (the parsed arguments)
#'   and `errors` (character vector of messages).
#' @keywords internal
validate_tool_arguments <- function(tool_obj, args) {
  parsed_args <- parse_tool_arguments(args, tool_name = tool_obj$name)

  problems <- if (inherits(tool_obj$parameters, "z_schema")) {
    validate_schema_value(parsed_args, tool_obj$parameters, "")
  } else {
    character(0)
  }

  list(
    valid = length(problems) == 0,
    arguments = parsed_args,
    errors = problems
  )
}
937+
938+
#' Build the tool result returned for a rejected tool call
#'
#' Assembles a structured error payload (error type, messages, expected
#' schema, received arguments and a retry suggestion) and wraps it in a tool
#' result flagged as a validation error.
#'
#' @param tc The tool call; its `id` and `name` are copied to the result.
#' @param tool_obj The tool object providing `name` and `parameters`.
#' @param validation A list with `errors` and `arguments`, as produced by
#'   `validate_tool_arguments()`.
#' @return A list with `id`, `name`, `result` (error text followed by the
#'   payload as JSON), `raw_result` (the payload), `is_error`,
#'   `is_validation_error` and `display_status`.
#' @keywords internal
tool_argument_validation_result <- function(tc, tool_obj, validation) {
  tool_name <- tool_obj$name
  required_fields <- unlist(
    tool_obj$parameters$required %||% character(0),
    use.names = FALSE
  )
  error_text <- paste(validation$errors, collapse = " ")
  retry_hint <- paste(
    "Re-emit the tool call with arguments that match the tool schema.",
    "Do not switch tools solely to work around an argument validation error."
  )

  payload <- list(
    error = TRUE,
    error_type = "invalid_tool_arguments",
    tool = tool_name,
    message = error_text,
    expected = list(
      required = as.list(required_fields),
      schema = schema_to_list(tool_obj$parameters)
    ),
    received_arguments = validation$arguments,
    suggestion = retry_hint
  )

  # Human-readable summary on the first line, full payload as JSON below it.
  rendered <- paste0(
    "Error: invalid arguments for tool '", tool_name, "': ",
    error_text,
    "\n",
    safe_to_json(payload, auto_unbox = TRUE)
  )

  list(
    id = tc$id,
    name = tc$name,
    result = rendered,
    raw_result = payload,
    is_error = TRUE,
    is_validation_error = TRUE,
    display_status = "invalid_arguments"
  )
}
973+
779974
#' @title Execute Tool Calls
780975
#' @description
781976
#' Execute a list of tool calls returned by an LLM. This function safely
@@ -850,6 +1045,14 @@ execute_tool_calls <- function(tool_calls, tools, hooks = NULL, envir = NULL,
8501045
))
8511046
}
8521047

1048+
if (tool_argument_validation_enabled(tool_obj)) {
1049+
validation <- validate_tool_arguments(tool_obj, tc$arguments)
1050+
if (!isTRUE(validation$valid)) {
1051+
return(tool_argument_validation_result(tc, tool_obj, validation))
1052+
}
1053+
tc$arguments <- validation$arguments
1054+
}
1055+
8531056
# Execute the tool with error handling (including hook errors)
8541057
result_list <- tryCatch(
8551058
{

R/utils_cli.R

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ tool_result_failed <- function(result, success = TRUE) {
158158
}
159159

160160
#' @keywords internal
161-
compact_tool_result_label <- function(name, result, success = TRUE) {
161+
compact_tool_result_label <- function(name, result, success = TRUE, display_status = NULL) {
162162
failed <- tool_result_failed(result, success)
163163

164164
base <- switch(name,
@@ -172,6 +172,10 @@ compact_tool_result_label <- function(name, result, success = TRUE) {
172172
name
173173
)
174174

175+
if (identical(display_status, "invalid_arguments")) {
176+
return(paste0(base, " call had invalid arguments"))
177+
}
178+
175179
if (failed) {
176180
paste0(base, " failed")
177181
} else {
@@ -494,9 +498,14 @@ cli_tool_result <- function(name, result, success = TRUE, raw_result = result) {
494498
}
495499

496500
failed <- tool_result_failed(result, success)
501+
display_status <- if (is.list(raw_result) && identical(raw_result$error_type %||% NULL, "invalid_tool_arguments")) {
502+
"invalid_arguments"
503+
} else {
504+
NULL
505+
}
497506

498507
if (tool_log_is_compact()) {
499-
label <- compact_tool_result_label(name, result, success = !failed)
508+
label <- compact_tool_result_label(name, result, success = !failed, display_status = display_status)
500509

501510
if (!requireNamespace("cli", quietly = TRUE)) {
502511
status <- if (failed) "\u2716" else "\u2714"

R/utils_console.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ analyze_tool_failures <- function(tool_results) {
2222
failure_counts <- list()
2323

2424
for (tr in tool_results) {
25+
if (isTRUE(tr$is_validation_error)) {
26+
next
27+
}
28+
2529
tool_name <- tr$name %||% "unknown"
2630

2731
# Check if this tool result represents a failure

0 commit comments

Comments
 (0)