|
| 1 | +package com.predic8.membrane.core.interceptor.mcp; |
| 2 | + |
| 3 | +import com.fasterxml.jackson.core.JsonProcessingException; |
| 4 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 5 | +import com.predic8.membrane.core.exchange.AbstractExchange; |
| 6 | +import com.predic8.membrane.core.interceptor.mcp.MCPUtil.InvalidToolArgumentsException; |
| 7 | +import com.predic8.membrane.core.mcp.MCPToolsCall; |
| 8 | +import com.predic8.membrane.core.mcp.MCPToolsCallResponse; |
| 9 | +import org.jetbrains.annotations.Nullable; |
| 10 | + |
| 11 | +import java.io.IOException; |
| 12 | +import java.nio.charset.StandardCharsets; |
| 13 | +import java.util.ArrayList; |
| 14 | +import java.util.LinkedHashMap; |
| 15 | +import java.util.List; |
| 16 | +import java.util.Map; |
| 17 | +import java.util.Set; |
| 18 | +import java.util.UUID; |
| 19 | + |
| 20 | +import static com.predic8.membrane.core.interceptor.mcp.ExchangeUtils.matchesExchangeFilter; |
| 21 | +import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalBooleanArgument; |
| 22 | +import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalIntArgument; |
| 23 | +import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalSizeArgument; |
| 24 | +import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalStringArgument; |
| 25 | +import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getRequiredLongArgument; |
| 26 | +import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.rejectUnexpectedArguments; |
| 27 | +import static com.predic8.membrane.core.interceptor.mcp.McpSchemaBuilder.integer; |
| 28 | +import static com.predic8.membrane.core.interceptor.mcp.McpSchemaBuilder.string; |
| 29 | +import static java.lang.Integer.MAX_VALUE; |
| 30 | + |
| 31 | +final class ExchangeToolSupport { |
| 32 | + |
| 33 | + static final String ARG_ID = "id"; |
| 34 | + static final String ARG_LIMIT = "limit"; |
| 35 | + static final String ARG_OFFSET = "offset"; |
| 36 | + static final String ARG_INCLUDE_BODIES = "includeBodies"; |
| 37 | + static final String ARG_HOST = "host"; |
| 38 | + static final String ARG_PORT = "port"; |
| 39 | + static final String ARG_PATH_PATTERN = "pathPattern"; |
| 40 | + static final String ARG_MAX_RESPONSE_SIZE = "maxResponseSize"; |
| 41 | + |
| 42 | + private static final ObjectMapper OM = new ObjectMapper(); |
| 43 | + |
| 44 | + private final McpPayloadSanitizer payloadSanitizer; |
| 45 | + |
| 46 | + ExchangeToolSupport(McpPayloadSanitizer payloadSanitizer) { |
| 47 | + this.payloadSanitizer = payloadSanitizer; |
| 48 | + } |
| 49 | + |
| 50 | + ExchangeQuery parseQuery(MCPToolsCall call, int maxExchanges) { |
| 51 | + rejectUnexpectedArguments(call, Set.of( |
| 52 | + ARG_LIMIT, |
| 53 | + ARG_OFFSET, |
| 54 | + ARG_INCLUDE_BODIES, |
| 55 | + ARG_HOST, |
| 56 | + ARG_PORT, |
| 57 | + ARG_PATH_PATTERN, |
| 58 | + ARG_MAX_RESPONSE_SIZE |
| 59 | + )); |
| 60 | + |
| 61 | + return new ExchangeQuery( |
| 62 | + getOptionalStringArgument(call, ARG_HOST), |
| 63 | + getOptionalPort(call), |
| 64 | + getOptionalStringArgument(call, ARG_PATH_PATTERN), |
| 65 | + getOptionalIntArgument(call, ARG_OFFSET, 0, 0, MAX_VALUE), |
| 66 | + getOptionalIntArgument(call, ARG_LIMIT, maxExchanges, 1, maxExchanges), |
| 67 | + getOptionalBooleanArgument(call, ARG_INCLUDE_BODIES, false), |
| 68 | + getOptionalMaxResponseSize(call) |
| 69 | + ); |
| 70 | + } |
| 71 | + |
| 72 | + ExchangeLookupQuery parseLookupQuery(MCPToolsCall call) { |
| 73 | + rejectUnexpectedArguments(call, Set.of( |
| 74 | + ARG_ID, |
| 75 | + ARG_INCLUDE_BODIES |
| 76 | + )); |
| 77 | + |
| 78 | + return new ExchangeLookupQuery( |
| 79 | + getRequiredLongArgument(call, ARG_ID), |
| 80 | + getOptionalBooleanArgument(call, ARG_INCLUDE_BODIES, false) |
| 81 | + ); |
| 82 | + } |
| 83 | + |
| 84 | + Map<String, Object> getExchangesSchema(int maxExchanges) { |
| 85 | + return McpSchemaBuilder.object() |
| 86 | + .property(ARG_LIMIT, integer().minimum(1).maximum(maxExchanges)) |
| 87 | + .property(ARG_OFFSET, integer().minimum(0).description("Number of newest matching exchanges to skip before collecting the page")) |
| 88 | + .property(ARG_INCLUDE_BODIES, McpSchemaBuilder.bool()) |
| 89 | + .property(ARG_HOST, string()) |
| 90 | + .property(ARG_PORT, integer().minimum(1).maximum(65535)) |
| 91 | + .property(ARG_PATH_PATTERN, string().description("Matches by prefix or regex")) |
| 92 | + .property(ARG_MAX_RESPONSE_SIZE, integer().minimum(1).description("Maximum size in bytes of the final JSON-RPC response body returned by this tool")) |
| 93 | + .additionalProperties(false) |
| 94 | + .build(); |
| 95 | + } |
| 96 | + |
| 97 | + Map<String, Object> getExchangeSchema() { |
| 98 | + return McpSchemaBuilder.object() |
| 99 | + .property(ARG_ID, integer().description("Exchange id")) |
| 100 | + .property(ARG_INCLUDE_BODIES, McpSchemaBuilder.bool()) |
| 101 | + .required(ARG_ID) |
| 102 | + .additionalProperties(false) |
| 103 | + .build(); |
| 104 | + } |
| 105 | + |
| 106 | + ExchangePage findPage(@Nullable List<AbstractExchange> allExchanges, ExchangeQuery query) { |
| 107 | + List<AbstractExchange> exchanges = allExchanges == null ? List.of() : allExchanges; |
| 108 | + List<AbstractExchange> page = new ArrayList<>(query.limit()); |
| 109 | + int skipped = 0; |
| 110 | + boolean hasMore = false; |
| 111 | + |
| 112 | + for (int i = exchanges.size() - 1; i >= 0; i--) { |
| 113 | + AbstractExchange exchange = exchanges.get(i); |
| 114 | + if (exchange.getResponse() == null) { |
| 115 | + continue; |
| 116 | + } |
| 117 | + if (!matchesExchangeFilter(exchange, query.host(), query.port(), query.pathPattern())) { |
| 118 | + continue; |
| 119 | + } |
| 120 | + if (skipped < query.offset()) { |
| 121 | + skipped++; |
| 122 | + continue; |
| 123 | + } |
| 124 | + if (page.size() < query.limit()) { |
| 125 | + page.addFirst(exchange); |
| 126 | + continue; |
| 127 | + } |
| 128 | + |
| 129 | + hasMore = true; |
| 130 | + break; |
| 131 | + } |
| 132 | + |
| 133 | + return new ExchangePage(page, hasMore, query.offset()); |
| 134 | + } |
| 135 | + |
| 136 | + MCPToolsCallResponse buildFullPageResponse(MCPToolsCall call, ExchangePage page, boolean includeBodies) { |
| 137 | + List<Map<String, Object>> describedExchanges = describeExchanges(page.exchanges(), includeBodies); |
| 138 | + return createExchangePageResponse( |
| 139 | + call, |
| 140 | + describedExchanges, |
| 141 | + page.hasMore(), |
| 142 | + nextOffset(page.offset(), describedExchanges.size(), page.hasMore()) |
| 143 | + ); |
| 144 | + } |
| 145 | + |
| 146 | + MCPToolsCallResponse buildSizedPageResponse(MCPToolsCall call, ExchangePage page, boolean includeBodies, int maxResponseSize) { |
| 147 | + TextResponseEnvelope responseEnvelope = measureTextResponseEnvelope(call); |
| 148 | + long maxResponseSizeLimit = maxResponseSize; |
| 149 | + long prefixBytes = measureEscapedJsonStringContentSize("{\"exchanges\":["); |
| 150 | + long separatorBytes = measureEscapedJsonStringContentSize(","); |
| 151 | + long minimumResponseSize = responseEnvelope.fixedBytes() + prefixBytes + measureExchangePageSuffixBytes(false, null); |
| 152 | + |
| 153 | + if (minimumResponseSize > maxResponseSizeLimit) throw new InvalidToolArgumentsException("Tool argument '" + ARG_MAX_RESPONSE_SIZE + "' must be at least " + minimumResponseSize + " bytes"); |
| 154 | + |
| 155 | + List<Map<String, Object>> describedExchanges = new ArrayList<>(); |
| 156 | + long exchangesBytes = 0; |
| 157 | + for (int i = page.exchanges().size() - 1; i >= 0; i--) { |
| 158 | + Map<String, Object> description = describeExchangeOrThrow(page.exchanges().get(i), includeBodies); |
| 159 | + long additionalExchangeBytes = measureExchangeBytes(description, describedExchanges.isEmpty(), separatorBytes); |
| 160 | + |
| 161 | + boolean hasMore = page.hasMore() || i > 0; |
| 162 | + Integer candidateNextOffset = nextOffset(page.offset(), describedExchanges.size() + 1, hasMore); |
| 163 | + long trackedSize = responseEnvelope.fixedBytes() |
| 164 | + + prefixBytes |
| 165 | + + exchangesBytes |
| 166 | + + additionalExchangeBytes |
| 167 | + + measureExchangePageSuffixBytes(hasMore, candidateNextOffset); |
| 168 | + if (trackedSize > maxResponseSizeLimit) { |
| 169 | + if (describedExchanges.isEmpty()) { |
| 170 | + throw new InvalidToolArgumentsException("Tool argument '" + ARG_MAX_RESPONSE_SIZE + "' must be at least " + trackedSize + " bytes to return the next exchange page"); |
| 171 | + } |
| 172 | + break; |
| 173 | + } |
| 174 | + |
| 175 | + describedExchanges.addFirst(description); |
| 176 | + exchangesBytes += additionalExchangeBytes; |
| 177 | + } |
| 178 | + |
| 179 | + boolean hasMore = page.hasMore() || describedExchanges.size() < page.exchanges().size(); |
| 180 | + return createExchangePageResponse( |
| 181 | + call, |
| 182 | + describedExchanges, |
| 183 | + hasMore, |
| 184 | + nextOffset(page.offset(), describedExchanges.size(), hasMore) |
| 185 | + ); |
| 186 | + } |
| 187 | + |
| 188 | + MCPToolsCallResponse buildSingleExchangeResponse(MCPToolsCall call, long exchangeId, @Nullable AbstractExchange exchange, boolean includeBodies) { |
| 189 | + if (exchange == null) { |
| 190 | + return MCPToolsCallResponse.toolError(call, "Exchange with id " + exchangeId + " was not found"); |
| 191 | + } |
| 192 | + |
| 193 | + Map<String, Object> description = MCPUtil.describeExchange(exchange, includeBodies, payloadSanitizer); |
| 194 | + if (description == null) { |
| 195 | + return MCPToolsCallResponse.toolError(call, "Exchange with id " + exchangeId + " has no response yet"); |
| 196 | + } |
| 197 | + |
| 198 | + return MCPToolsCallResponse.from(call) |
| 199 | + .withJson(Map.of("exchange", description)); |
| 200 | + } |
| 201 | + |
| 202 | + private List<Map<String, Object>> describeExchanges(List<AbstractExchange> exchanges, boolean includeBodies) { |
| 203 | + return exchanges.stream() |
| 204 | + .map(exchange -> describeExchangeOrThrow(exchange, includeBodies)) |
| 205 | + .toList(); |
| 206 | + } |
| 207 | + |
| 208 | + private Map<String, Object> describeExchangeOrThrow(AbstractExchange exchange, boolean includeBodies) { |
| 209 | + Map<String, Object> description = MCPUtil.describeExchange(exchange, includeBodies, payloadSanitizer); |
| 210 | + if (description == null) { |
| 211 | + throw new IllegalStateException("Expected exchange response data to be present for paging"); |
| 212 | + } |
| 213 | + return description; |
| 214 | + } |
| 215 | + |
| 216 | + private MCPToolsCallResponse createExchangePageResponse(MCPToolsCall call, List<Map<String, Object>> exchanges, boolean hasMore, @Nullable Integer nextOffset) { |
| 217 | + return MCPToolsCallResponse.from(call) |
| 218 | + .withJson(buildExchangePagePayload(exchanges, hasMore, nextOffset)); |
| 219 | + } |
| 220 | + |
| 221 | + private Map<String, Object> buildExchangePagePayload(List<Map<String, Object>> exchanges, boolean hasMore, @Nullable Integer nextOffset) { |
| 222 | + LinkedHashMap<String, Object> payload = new LinkedHashMap<>(); |
| 223 | + payload.put("exchanges", exchanges); |
| 224 | + payload.put("hasMore", hasMore); |
| 225 | + if (nextOffset != null) { |
| 226 | + payload.put("nextOffset", nextOffset); |
| 227 | + } |
| 228 | + return payload; |
| 229 | + } |
| 230 | + |
| 231 | + private @Nullable Integer nextOffset(int offset, int returnedCount, boolean hasMore) { |
| 232 | + return hasMore ? offset + returnedCount : null; |
| 233 | + } |
| 234 | + |
| 235 | + private long measureExchangeBytes(Map<String, Object> description, boolean firstExchange, long separatorBytes) { |
| 236 | + return measureEscapedJsonStringContentSize(serializeJson(description)) + (firstExchange ? 0 : separatorBytes); |
| 237 | + } |
| 238 | + |
| 239 | + // Measure the fixed JSON-RPC/MCP wrapper once with a placeholder so the byte limit |
| 240 | + // applies to the final serialized response body, not just the unescaped payload text. |
| 241 | + private TextResponseEnvelope measureTextResponseEnvelope(MCPToolsCall call) { |
| 242 | + String marker = "__MEMBRANE_MCP_TEXT_PLACEHOLDER_" + UUID.randomUUID() + "__"; |
| 243 | + try { |
| 244 | + String responseJson = MCPToolsCallResponse.from(call).withText(marker).toJson(); |
| 245 | + int markerIndex = responseJson.indexOf(marker); |
| 246 | + if (markerIndex < 0) { |
| 247 | + throw new IllegalStateException("Could not locate placeholder marker in serialized MCP response"); |
| 248 | + } |
| 249 | + |
| 250 | + return new TextResponseEnvelope( |
| 251 | + utf8Size(responseJson.substring(0, markerIndex)), |
| 252 | + utf8Size(responseJson.substring(markerIndex + marker.length())) |
| 253 | + ); |
| 254 | + } catch (IOException e) { |
| 255 | + throw new RuntimeException("Failed to serialize MCP response envelope", e); |
| 256 | + } |
| 257 | + } |
| 258 | + |
| 259 | + private long measureExchangePageSuffixBytes(boolean hasMore, @Nullable Integer nextOffset) { |
| 260 | + return measureEscapedJsonStringContentSize(buildExchangePageSuffix(hasMore, nextOffset)); |
| 261 | + } |
| 262 | + |
| 263 | + private String buildExchangePageSuffix(boolean hasMore, @Nullable Integer nextOffset) { |
| 264 | + return "],\"hasMore\":" + hasMore + (nextOffset == null ? "" : ",\"nextOffset\":" + nextOffset) + "}"; |
| 265 | + } |
| 266 | + |
| 267 | + private String serializeJson(Object value) { |
| 268 | + try { |
| 269 | + return OM.writeValueAsString(value); |
| 270 | + } catch (JsonProcessingException e) { |
| 271 | + throw new RuntimeException("Failed to serialize JSON value", e); |
| 272 | + } |
| 273 | + } |
| 274 | + |
| 275 | + private long measureEscapedJsonStringContentSize(String value) { |
| 276 | + try { |
| 277 | + return OM.writeValueAsBytes(value).length - 2; |
| 278 | + } catch (JsonProcessingException e) { |
| 279 | + throw new RuntimeException("Failed to serialize JSON string content", e); |
| 280 | + } |
| 281 | + } |
| 282 | + |
| 283 | + private static long utf8Size(String value) { |
| 284 | + return value.getBytes(StandardCharsets.UTF_8).length; |
| 285 | + } |
| 286 | + |
| 287 | + private static @Nullable Integer getOptionalPort(MCPToolsCall call) { |
| 288 | + return call.getArgument(ARG_PORT) == null ? null : getOptionalIntArgument(call, ARG_PORT, -1, 1, 65535); |
| 289 | + } |
| 290 | + |
| 291 | + private static @Nullable Integer getOptionalMaxResponseSize(MCPToolsCall call) { |
| 292 | + return call.getArgument(ARG_MAX_RESPONSE_SIZE) == null ? null : getOptionalSizeArgument(call, ARG_MAX_RESPONSE_SIZE, -1, 1, MAX_VALUE); |
| 293 | + } |
| 294 | + |
| 295 | + record ExchangeQuery( |
| 296 | + @Nullable String host, |
| 297 | + @Nullable Integer port, |
| 298 | + @Nullable String pathPattern, |
| 299 | + int offset, |
| 300 | + int limit, |
| 301 | + boolean includeBodies, |
| 302 | + @Nullable Integer maxResponseSize |
| 303 | + ) { |
| 304 | + } |
| 305 | + |
| 306 | + record ExchangeLookupQuery(long id, boolean includeBodies) { |
| 307 | + } |
| 308 | + |
| 309 | + record ExchangePage(List<AbstractExchange> exchanges, boolean hasMore, int offset) { |
| 310 | + } |
| 311 | + |
| 312 | + private record TextResponseEnvelope(long prefixBytes, long suffixBytes) { |
| 313 | + private long fixedBytes() { |
| 314 | + return prefixBytes + suffixBytes; |
| 315 | + } |
| 316 | + } |
| 317 | +} |
0 commit comments