|
#' AI_applyAnonymization: Automatically apply anonymization strategy using LLM
#'
#' Uses an agentic loop to explore multiple anonymization strategies.
#' The LLM proposes strategies as structured tool calls, each is evaluated
#' with a combined utility score, and the best (lowest score) is selected.
#'
#' @details
#' The function runs in two phases. In the batch phase the LLM is asked for
#' \code{n_strategies} strategies; each one is applied to a copy of
#' \code{sdcObj}, followed by \code{localSuppression()} with the requested
#' \code{k}, and scored with \code{ai_utility_score()}. In the refinement
#' phase the LLM is shown the results so far and asked, up to \code{max_iter}
#' times, for one improved strategy, which is evaluated the same way.
#'
#' A strategy whose tool calls, local suppression or scoring fail (or whose
#' total score is missing) is skipped. A failing LLM request during
#' refinement raises a warning and skips that iteration, so the strategies
#' evaluated so far are not lost. If no batch strategy can be evaluated, an
#' error is raised.
#'
#' @param sdcObj An object of class sdcMicroObj.
#' @param k Desired k-anonymity level (default 3).
#' @param verbose If \code{TRUE}, prints progress and scores for each strategy.
#' @param model LLM model identifier. If \code{NULL}, a default is chosen per provider.
#' @param api_key API key. If \code{NULL}, auto-detected from environment variables.
#' @param provider LLM provider: \code{"openai"} (default), \code{"anthropic"}, or
#'   \code{"custom"} for any OpenAI-compatible endpoint.
#' @param base_url Base URL for the API endpoint. Required when \code{provider = "custom"}.
#' @param confirm Logical; if \code{TRUE} (default) and session is interactive, shows the
#'   best strategy (even when \code{verbose = FALSE}) and asks for confirmation
#'   before applying. Answering \code{n} or \code{q} returns \code{sdcObj}
#'   unchanged.
#' @param max_iter Number of refinement iterations after the initial batch (default 2).
#'   Must be a single non-negative number.
#' @param n_strategies Number of strategies in the initial batch (default 3).
#' @param weights Numeric vector of length 3: weights for suppression rate,
#'   category loss, and IL1 in the utility score. Default \code{c(1/3, 1/3, 1/3)}.
#' @param generateReport If \code{TRUE}, generates internal and external reports
#'   via \code{report()} using the file names
#'   \code{"anonymization_internal.html"} and
#'   \code{"anonymization_external.html"}.
#' @return Modified sdcMicroObj with the best anonymization strategy applied,
#'   or the unmodified \code{sdcObj} if the user declines the strategy.
#' @author Matthias Templ
#' @export
#' @examples
#' \dontrun{
#' if (interactive() && nzchar(Sys.getenv("OPENAI_API_KEY"))) {
#'   library(sdcMicro)
#'   data(testdata)
#'   sdc <- AI_createSdcObj(dat = testdata, policy = "open", confirm = FALSE)
#'   sdc <- AI_applyAnonymization(sdcObj = sdc, k = 3, verbose = TRUE, confirm = FALSE)
#' }
#' }
AI_applyAnonymization <- function(sdcObj, k = 3, verbose = TRUE,
                                  model = NULL, api_key = NULL,
                                  provider = c("openai", "anthropic", "custom"),
                                  base_url = NULL, confirm = TRUE,
                                  max_iter = 2, n_strategies = 3,
                                  weights = c(1/3, 1/3, 1/3),
                                  generateReport = TRUE) {
  provider <- match.arg(provider)

  # Validate cheap-to-check arguments before any LLM request is made.
  if (!is.numeric(weights) || length(weights) != 3L || anyNA(weights)) {
    stop("'weights' must be a numeric vector of length 3 without missing values.",
         call. = FALSE)
  }
  if (!is.numeric(max_iter) || length(max_iter) != 1L ||
      is.na(max_iter) || max_iter < 0) {
    stop("'max_iter' must be a single non-negative number.", call. = FALSE)
  }

  summary_info <- summarize_sdcObj_structure(sdcObj, k)
  tool_schemas <- get_tool_schemas()

  batch_system_prompt <- "You are an expert in statistical disclosure control."
  refine_system_prompt <- "You are an expert in statistical disclosure control. Always respond with valid JSON only, no markdown or prose."

  # ---- Local helpers ------------------------------------------------------

  # query_llm() may return a list (native mode) or a character string (text
  # mode); reduce both to the text content.
  response_text <- function(x) {
    if (is.list(x)) x$content else x
  }

  # Parse the batch response; re-signal parse errors with context.
  parse_batch <- function(text) {
    tryCatch(
      parse_strategies_json(text),
      error = function(e) {
        stop("Failed to parse batch strategies from LLM: ",
             conditionMessage(e), call. = FALSE)
      }
    )
  }

  # Apply one strategy to a fresh copy of the original object, enforce
  # k-anonymity and score the result. Returns list(sdcObj, score), or NULL
  # when any step fails or the total score cannot be ranked.
  evaluate_strategy <- function(calls, indent, fail_label) {
    report_failure <- function(msg) {
      if (verbose) message(paste0(indent, fail_label, ": ", msg))
      NULL
    }
    evaluated <- tryCatch(
      {
        candidate <- execute_tool_calls(sdcObj, calls)
        if (is.null(candidate)) {
          stop("tool calls returned no object", call. = FALSE)
        }
        candidate <- localSuppression(candidate, k = k)
        list(
          sdcObj = candidate,
          score = ai_utility_score(sdcObj, candidate, weights)
        )
      },
      error = function(e) report_failure(conditionMessage(e))
    )
    if (is.null(evaluated)) {
      return(NULL)
    }

    total <- evaluated$score$total
    if (!is.numeric(total) || length(total) != 1L || is.na(total)) {
      # An unrankable score would break which.min() and the `<` comparison.
      return(report_failure("utility score is missing"))
    }

    if (verbose) {
      message(sprintf("%sU=%.4f (S=%.4f, C=%.4f, IL1=%.4f)",
                      indent, total, evaluated$score$suppression_rate,
                      evaluated$score$category_loss, evaluated$score$il1))
    }
    evaluated
  }

  # ====== Phase 1: Batch ======
  if (verbose) message("=== Batch phase: requesting ", n_strategies, " strategies ===")

  prompt <- build_agentic_prompt(summary_info, k, n_strategies, tool_schemas)

  # Try native tool calling for openai/anthropic, text fallback for custom
  use_native_tools <- provider %in% c("openai", "anthropic")

  if (use_native_tools) {
    response <- query_llm(
      prompt = prompt,
      system_prompt = batch_system_prompt,
      provider = provider, model = model, api_key = api_key,
      base_url = base_url, tools = tool_schemas
    )
    # Tolerate a plain-text reply so the `$` accesses below are always valid.
    if (!is.list(response)) {
      response <- list(content = response)
    }
    # The prompt asks for a JSON array of strategies even with native tools,
    # so the reply is either tool calls (one strategy) or JSON text (several).
    if (length(response$tool_calls) > 0) {
      strategies <- list(list(
        name = "native",
        reasoning = if (!is.null(response$content)) response$content else "",
        calls = response$tool_calls
      ))
    } else if (!is.null(response$content)) {
      strategies <- parse_batch(response$content)
    } else {
      stop("LLM returned neither tool calls nor content.", call. = FALSE)
    }
  } else {
    # Text/JSON fallback
    strategies <- parse_batch(response_text(query_llm(
      prompt = prompt,
      system_prompt = batch_system_prompt,
      provider = provider, model = model, api_key = api_key,
      base_url = base_url
    )))
  }

  # Evaluate each strategy
  results <- list()
  for (i in seq_along(strategies)) {
    strat <- strategies[[i]]
    if (verbose) {
      name <- if (!is.null(strat$name)) strat$name else paste("Strategy", i)
      message(sprintf("  Evaluating %s...", name))
    }

    evaluated <- evaluate_strategy(strat$calls, indent = "    ",
                                   fail_label = "Failed")
    if (is.null(evaluated)) next

    results[[length(results) + 1]] <- list(
      strategy = strat, sdcObj = evaluated$sdcObj, score = evaluated$score
    )
  }

  if (length(results) == 0) {
    stop("All strategies failed. Cannot proceed.", call. = FALSE)
  }

  # Find best (lower total score is better)
  best_idx <- which.min(vapply(results, function(r) r$score$total, numeric(1)))
  best <- results[[best_idx]]

  # ====== Phase 2: Refinement ======
  for (iter in seq_len(max_iter)) {
    if (verbose) message(sprintf("=== Refinement iteration %d/%d ===", iter, max_iter))

    ref_prompt <- build_refinement_prompt(results, summary_info, k)

    # Refinement always uses text/JSON mode (no native tools) for reliable
    # parsing. A failing request only costs this iteration, not the results
    # already obtained.
    ref_content <- tryCatch(
      response_text(query_llm(
        prompt = ref_prompt,
        system_prompt = refine_system_prompt,
        provider = provider, model = model, api_key = api_key,
        base_url = base_url
      )),
      error = function(e) {
        warning("Refinement request failed, skipping iteration: ",
                conditionMessage(e), call. = FALSE)
        NULL
      }
    )
    if (is.null(ref_content)) next

    refined <- tryCatch(
      parse_single_strategy_json(ref_content),
      error = function(e) {
        if (verbose) message("  Could not parse refinement response, skipping.")
        NULL
      }
    )
    if (is.null(refined)) next

    evaluated <- evaluate_strategy(refined$calls, indent = "  ",
                                   fail_label = "Refinement failed")
    if (is.null(evaluated)) next

    result_entry <- list(strategy = refined, sdcObj = evaluated$sdcObj,
                         score = evaluated$score)
    results[[length(results) + 1]] <- result_entry

    if (isTRUE(result_entry$score$total < best$score$total)) {
      best <- result_entry
      if (verbose) message("  -> New best!")
    }
  }

  # ====== Show result + confirm ======
  will_prompt <- isTRUE(confirm) && interactive()

  # The summary is also shown when only `confirm` is set: the user must see
  # what they are asked to approve.
  if (verbose || will_prompt) {
    message(sprintf("\n=== Best strategy: '%s' (U=%.4f) ===",
                    if (!is.null(best$strategy$name)) best$strategy$name else "best",
                    best$score$total))
    # Collapse first: parsed JSON may yield a list or a multi-element vector.
    reasoning <- paste(unlist(best$strategy$reasoning), collapse = " ")
    if (nzchar(reasoning)) {
      cat(strwrap(reasoning, width = 75), sep = "\n")
      cat("\n")
    }
    message(sprintf("  Suppression rate: %.4f", best$score$suppression_rate))
    message(sprintf("  Category loss:    %.4f", best$score$category_loss))
    message(sprintf("  IL1:              %.4f", best$score$il1))
  }

  if (will_prompt) {
    answer <- readline("Apply this strategy? [Y/n/q] ")
    answer <- tolower(trimws(answer))
    if (answer %in% c("q", "quit")) {
      message("Aborted by user.")
      return(sdcObj)
    }
    if (answer %in% c("n", "no")) {
      message("Rejected. Returning original sdcObj unchanged.")
      return(sdcObj)
    }
    # Any other answer (including empty input) accepts the strategy.
  }

  sdcObj <- best$sdcObj

  if (generateReport) {
    # NOTE(review): confirm whether report() appends ".html" to `filename`
    # itself and which output directory it uses by default.
    report(sdcObj, filename = "anonymization_internal.html", internal = TRUE)
    report(sdcObj, filename = "anonymization_external.html", internal = FALSE)
  }

  sdcObj
}
0 commit comments