From cf84c3b7f79f7229934a076f48db6465edbc58f8 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 7 May 2026 15:21:59 +0200 Subject: [PATCH 01/43] feat: add AI API usage tracking and management components --- .../yaml/parsing/binding/ObjectBinder.java | 17 +++-- .../membrane/core/interceptor/ai/AiUtil.java | 55 +++++++++++++++ .../core/interceptor/ai/OpenAiApiUtil.java | 70 +++++++++++++++++++ .../core/interceptor/ai/store/AiApiLimit.java | 57 +++++++++++++++ .../core/interceptor/ai/store/AiApiStore.java | 18 +++++ .../core/interceptor/ai/store/AiApiUser.java | 29 ++++++++ .../ai/store/JDBCAiApiUsageStore.java | 70 +++++++++++++++++++ .../ai/store/SimpleAiApiStore.java | 53 ++++++++++++++ .../core/interceptor/ai/store/Usage.java | 3 + .../core/security/AbstractSecurityScheme.java | 7 +- .../core/util/jdbc/AbstractJdbcSupport.java | 27 +++++-- 11 files changed, 394 insertions(+), 12 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java diff --git a/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java b/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java index a018d1b509..d08f99d703 100644 --- a/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java +++ b/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java @@ -33,11 +33,7 @@ import java.util.List; import java.util.Objects; -import static com.predic8.membrane.annot.yaml.McYamlIntrospector.findRequiredSetters; -import static com.predic8.membrane.annot.yaml.McYamlIntrospector.findSingleSetterOrNullForAnnotation; -import static com.predic8.membrane.annot.yaml.McYamlIntrospector.getSingleChildSetter; -import static com.predic8.membrane.annot.yaml.McYamlIntrospector.isCollapsed; -import static com.predic8.membrane.annot.yaml.McYamlIntrospector.isNoEnvelope; +import static com.predic8.membrane.annot.yaml.McYamlIntrospector.*; import static com.predic8.membrane.annot.yaml.NodeValidationUtils.ensureMappingStart; public final class ObjectBinder { @@ -49,7 +45,8 @@ public final class ObjectBinder { public static T bind(ParsingContext pc, Class clazz, JsonNode node) throws ConfigurationParsingException { try { - T configObj = clazz.getConstructor().newInstance(); + T configObj = instantiate(clazz); + BeanDefinition currentBeanDefinition = BeanDefinitionContext.current(); if (currentBeanDefinition != null && pc.getRegistry() != null) { pc.getRegistry().rememberBeanDefinition(configObj, currentBeanDefinition); @@ -102,6 +99,14 @@ public static T bind(ParsingContext pc, Class clazz, JsonNode node) th } } + private static @NotNull T instantiate(Class clazz) throws InvocationTargetException, InstantiationException, IllegalAccessException { + try { + return clazz.getConstructor().newInstance(); + } catch (NoSuchMethodException e) { + throw new ConfigurationParsingException("Class %s does not have a public no-arg constructor.".formatted(clazz.getName())); + } + } + private static @NotNull T handleCollapsed(ParsingContext ctx, Class clazz, JsonNode node, T configObj) { if (node.isNull()) throw new ConfigurationParsingException("Collapsed element must not be null."); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java new file mode 100644 index 0000000000..acba05bd67 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java @@ -0,0 +1,55 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.predic8.membrane.core.http.Header; + +public class AiUtil { + + public static final String BEARER_PREFIX = "Bearer"; + + private AiUtil() {} + + /** + * Estimates the number of tokens in a given text. + * The calculation assumes an average token length of 4 characters. + * + * Content Approximation + * English prose chars / 4 + * German/French chars / 3.5 + * JSON/XML/code chars / 2.5–3 + * Chinese/Japanese very different + * + * For API gateways, quotas, billing alerts, or rate limiting, approximate counting is often sufficient. + * + * @param text the input string whose tokens are to be estimated + * @return the estimated number of tokens, rounded up to the nearest integer + */ + public static int estimateTokens(String text) { + return (int) Math.ceil(text.length() / 4.0); + } + + /** + * Extracts the Bearer token from the Authorization header. + * If the Authorization header is null or does not contain + * a Bearer token, this method returns null. + * + * @param header the Header object from which the Authorization + * header is to be extracted + * @return the Bearer token as a String if present; otherwise null + */ + public static String extractBearerToken(Header header) { + var ah = header.getAuthorization(); + if (ah == null) { + return null; + } + + int index = ah.indexOf(BEARER_PREFIX); + if (index < 0) { + return null; + } + + var token = ah.substring(index + BEARER_PREFIX.length()).trim(); + + return token.isEmpty() ? null : token; + } + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java new file mode 100644 index 0000000000..7a9c661516 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java @@ -0,0 +1,70 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.predic8.membrane.core.http.Response; + +import static com.predic8.membrane.core.http.Header.WWW_AUTHENTICATE; +import static com.predic8.membrane.core.http.Response.badRequest; +import static com.predic8.membrane.core.http.Response.unauthorized; + +public class OpenAiApiUtil { + + private static final ObjectMapper om = new ObjectMapper(); + + public static Response authenticationFailed() { + return unauthorized().header(WWW_AUTHENTICATE, "Bearer").json(createJson(new ErrorEnvelope( + new ErrorBody( + "Invalid authentication credentials", + "invalid_request_error", + null, + "invalid_authentication" + ) + ))).build(); + } + + public static Response contextLengthExceeded(int maxTokens, int estimatedTokens) { + return badRequest().json(createJson(new ErrorBody( + """ + This model's maximum context length is %d tokens. + Your request contains approximately %d tokens. + """.formatted(maxTokens, estimatedTokens).trim(), + "invalid_request_error", + "input", + "context_length_exceeded" + ))).build(); + } + + public static Response tokenLimitExceeded() { + return badRequest() + .json(createJson(new ErrorEnvelope( + new ErrorBody( + "Token rate limit exceeded.", + "rate_limit_error", + null, + "token_limit_exceeded" + ) + ))) + .build(); + } + + public static String createJson(Object o) { + try { + return om.writeValueAsString(o); + } catch (Exception e) { + return """ + { "error": "Could not create JSON" } + """; + } + } + + record ErrorEnvelope(ErrorBody error) { + } + + record ErrorBody( + String message, + String type, + String param, + String code + ) { + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java new file mode 100644 index 0000000000..925c38ca1a --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -0,0 +1,57 @@ +package com.predic8.membrane.core.interceptor.ai.store; + +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Instant; + +import static java.time.Instant.now; + +@MCElement(name = "limit", component = false, id = "ai-api-limit") +public class AiApiLimit { + + private static final Logger log = LoggerFactory.getLogger(AiApiLimit.class); + + private int maxTokens; + private int period; + private Instant nextReset; + private long tokens; + + public AiApiLimit() { + nextReset = now().plusSeconds(period); + } + + public long checkLimit() { + if (now().isAfter(nextReset)) { + tokens = 0; + nextReset = now().plusSeconds(period); + log.debug("Resetting AI API usage limit."); + } + return maxTokens - tokens; + } + + public void addTokens(long tokens) { + log.debug("Adding {} tokens to AI API usage limit.", tokens); + this.tokens += tokens; + } + + public int getMaxTokens() { + return maxTokens; + } + + @MCAttribute + public void setMaxTokens(int maxTokens) { + this.maxTokens = maxTokens; + } + + public int getPeriod() { + return period; + } + + @MCAttribute + public void setPeriod(int period) { + this.period = period; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java new file mode 100644 index 0000000000..41eafc6de5 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java @@ -0,0 +1,18 @@ +package com.predic8.membrane.core.interceptor.ai.store; + +import com.predic8.membrane.core.router.Router; + +import java.util.Optional; + +public interface AiApiStore { + + default void init(Router router) { + } + + void store(String user, Usage usage); + + Optional getUser(String token); + + long checkLimit(AiApiUser user); +} + diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java new file mode 100644 index 0000000000..61125c1d05 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -0,0 +1,29 @@ +package com.predic8.membrane.core.interceptor.ai.store; + +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; + +@MCElement(name = "users", component = false, id="ai-api-users") +public class AiApiUser { + + private String name; + private String token; + + public String getName() { + return name; + } + + @MCAttribute() + public void setName(String name) { + this.name = name; + } + + public String getToken() { + return token; + } + + @MCAttribute() + public void setToken(String token) { + this.token = token; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java new file mode 100644 index 0000000000..541ed27d10 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java @@ -0,0 +1,70 @@ +package com.predic8.membrane.core.interceptor.ai.store; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.router.Router; +import com.predic8.membrane.core.util.jdbc.AbstractJdbcSupport; + +import java.sql.SQLException; +import java.util.Optional; + +@MCElement(name = "jdbcAiApiUsageStore") +public class JDBCAiApiUsageStore extends AbstractJdbcSupport implements AiApiStore { + + private static final String CREATE_TABLE_SQL = """ + CREATE TABLE IF NOT EXISTS ai_api_usage ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + username VARCHAR(255) NOT NULL, + input_tokens INT NOT NULL, + output_tokens INT NOT NULL, + total_tokens INT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """; + + private static final String INSERT_SQL = """ + INSERT INTO ai_api_usage ( + username, + input_tokens, + output_tokens, + total_tokens + ) VALUES (?, ?, ?, ?) + """; + + @Override + public void init(Router router) { + super.init(router); + createTablesIfNotExist(); + } + + @Override + public void store(String user, com.predic8.membrane.core.interceptor.ai.store.Usage usage) { + try (var connection = getConnection(); var ps = connection.prepareStatement(INSERT_SQL)) { + ps.setString(1, user); + ps.setInt(2, usage.inputTokens()); + ps.setInt(3, usage.outputTokens()); + ps.setInt(4, usage.totalTokens()); + + ps.executeUpdate(); + } catch (SQLException e) { + throw new RuntimeException("Could not store AI API usage.", e); + } + } + + @Override + public Optional getUser(String token) { + return Optional.empty(); + } + + @Override + public long checkLimit(AiApiUser user) { + return 0; + } + + private void createTablesIfNotExist() { + try (var connection = getConnection(); var ps = connection.prepareStatement(CREATE_TABLE_SQL)) { + ps.executeUpdate(); + } catch (SQLException e) { + throw new RuntimeException("Could not create AI API usage table.", e); + } + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java new file mode 100644 index 0000000000..6d13716360 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -0,0 +1,53 @@ +package com.predic8.membrane.core.interceptor.ai.store; + +import com.predic8.membrane.annot.MCChildElement; +import com.predic8.membrane.annot.MCElement; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Optional; + +@MCElement(name="simpleStore",component = false, noEnvelope = false, id="simple-ai-api-store") +public class SimpleAiApiStore implements AiApiStore { + + private static final Logger log = LoggerFactory.getLogger(SimpleAiApiStore.class); + + private List users; + private AiApiLimit limit = new AiApiLimit(); + + @Override + public void store(String user, Usage usage) { + log.info("User: {} Usage: {}", user, usage); + limit.addTokens(usage.totalTokens()); + } + + @Override + public Optional getUser(String token) { + return users.stream().filter(u -> u.getToken().equals(token)).findFirst(); + } + + @Override + public long checkLimit(AiApiUser user) { + return limit.checkLimit(); + } + + @MCChildElement(allowForeign = true,order = 10) + public void setUsers(List users) { + this.users = users; + } + + public List getUsers() { + return users; + } + + public AiApiLimit getLimit() { + return limit; + } + + @MCChildElement(allowForeign = true) + public void setLimit(AiApiLimit limit) { + this.limit = limit; + } +} + diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java new file mode 100644 index 0000000000..9288bba508 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java @@ -0,0 +1,3 @@ +package com.predic8.membrane.core.interceptor.ai.store; + +public record Usage(int inputTokens, int outputTokens, int totalTokens) {} diff --git a/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java b/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java index b936ec3385..72beab6dbb 100644 --- a/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java +++ b/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java @@ -13,7 +13,7 @@ limitations under the License. */ package com.predic8.membrane.core.security; -import com.predic8.membrane.core.exchange.*; +import com.predic8.membrane.core.exchange.Exchange; import java.util.*; @@ -58,4 +58,9 @@ public boolean hasScope(String scope) { public Set getScopes() { return scopes; } + + @Override + public String getPrincipal() { + return null; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/util/jdbc/AbstractJdbcSupport.java b/core/src/main/java/com/predic8/membrane/core/util/jdbc/AbstractJdbcSupport.java index 6103666e1b..df865bbda5 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/jdbc/AbstractJdbcSupport.java +++ b/core/src/main/java/com/predic8/membrane/core/util/jdbc/AbstractJdbcSupport.java @@ -14,12 +14,16 @@ package com.predic8.membrane.core.util.jdbc; -import com.predic8.membrane.annot.*; -import com.predic8.membrane.core.router.*; -import com.predic8.membrane.core.util.*; +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.core.router.Router; +import com.predic8.membrane.core.util.ConfigurationException; -import javax.sql.*; -import java.util.*; +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Map; + +import static com.predic8.membrane.core.util.ExceptionUtil.getRootCause; public abstract class AbstractJdbcSupport { @@ -53,6 +57,19 @@ public void init(Router router) { getDatasourceIfNull(); } + // @TODO make subclasses use this method + public Connection getConnection() { + try { + return datasource.getConnection(); + } catch (SQLException e) { + var root = getRootCause(e); + if (root instanceof ClassNotFoundException) { + throw new ConfigurationException("JDBC driver not found. Please add the JDBC driver to the classpath: " + root.getMessage()); + } + throw new RuntimeException(e); + } + } + private void getDatasourceIfNull() { if (datasource != null) return; From d792f5fb24205a254c32bacbb998a437a724e1d5 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 7 May 2026 15:24:07 +0200 Subject: [PATCH 02/43] feat: implement OpenAI API interceptor with usage tracking and token limits --- .../yaml/parsing/binding/ObjectBinder.java | 1 - .../membrane/core/interceptor/ai/AiUtil.java | 6 +- .../interceptor/ai/OpenAIAPIInterceptor.java | 178 ++++++++++++++++++ .../ai/store/SimpleAiApiStore.java | 2 +- .../core/security/AbstractSecurityScheme.java | 4 - 5 files changed, 182 insertions(+), 9 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java diff --git a/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java b/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java index d08f99d703..8ca91e4ebb 100644 --- a/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java +++ b/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ObjectBinder.java @@ -122,7 +122,6 @@ private static T handleNoEnvelopeList(ParsingContext pc, Class clazz, return configObj; } - @SuppressWarnings("ConstantValue") private static void applyCollapsedScalar(Class clazz, JsonNode node, T target) { Method attributeSetter = findSingleSetterOrNullForAnnotation(clazz, MCAttribute.class); Method textSetter = findSingleSetterOrNullForAnnotation(clazz, MCTextContent.class); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java index acba05bd67..5c134d0ae9 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java @@ -11,15 +11,15 @@ private AiUtil() {} /** * Estimates the number of tokens in a given text. * The calculation assumes an average token length of 4 characters. - * + *

* Content Approximation * English prose chars / 4 * German/French chars / 3.5 * JSON/XML/code chars / 2.5–3 * Chinese/Japanese very different - * + *

* For API gateways, quotas, billing alerts, or rate limiting, approximate counting is often sufficient. - * + *

* @param text the input string whose tokens are to be estimated * @return the estimated number of tokens, rounded up to the nearest integer */ diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java new file mode 100644 index 0000000000..d609f51260 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java @@ -0,0 +1,178 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCChildElement; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.AbstractInterceptor; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +import static com.predic8.membrane.core.http.Header.AUTHORIZATION; +import static com.predic8.membrane.core.interceptor.Interceptor.Flow.REQUEST; +import static com.predic8.membrane.core.interceptor.Interceptor.Flow.RESPONSE; +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; +import static com.predic8.membrane.core.interceptor.Outcome.RETURN; +import static com.predic8.membrane.core.interceptor.ai.AiUtil.estimateTokens; +import static com.predic8.membrane.core.interceptor.ai.AiUtil.extractBearerToken; +import static com.predic8.membrane.core.interceptor.ai.OpenAiApiUtil.*; + +@MCElement(name = "openAI") +public class OpenAIAPIInterceptor extends AbstractInterceptor { + + private static final Logger log = LoggerFactory.getLogger(OpenAIAPIInterceptor.class); + + public static final String MEMBRANE_AI_USERTOKEN = "membrane.ai.usertoken"; + public static final String MAX_OUTPUT_TOKENS = "max_output_tokens"; + + private static final ObjectMapper om = new ObjectMapper(); + + private String apiKey; + private int maxOutputTokens; + private int maxInputTokens; + private AiApiStore store; + + @Override + public void init() { + if (store != null) + store.init(router); + } + + @Override + public Outcome handleRequest(Exchange exc) { + var header = exc.getRequest().getHeader(); + + if (store != null) { + var user = store.getUser(extractBearerToken(header)); + log.debug("User: {}", user); + if (user.isEmpty()) { + exc.setResponse(authenticationFailed()); + return RETURN; + } + var remaining = store.checkLimit(user.get()); + if (remaining <= 0) { + exc.setResponse(tokenLimitExceeded()); + return RETURN; + } + exc.setProperty(MEMBRANE_AI_USERTOKEN, user); + } + + header.removeFields(AUTHORIZATION); + header.add(AUTHORIZATION, "Bearer " + apiKey); + + var json = getJson(exc, REQUEST); + + if (maxOutputTokens != 0) { + json.put(MAX_OUTPUT_TOKENS, maxOutputTokens); + } + + if (maxInputTokens != 0) { + var input = json.get("input"); + if (input != null) { + var estimated = estimateTokens(input.asText()); + if (estimated > maxInputTokens) { + exc.setResponse(contextLengthExceeded(maxInputTokens, estimated)); + return RETURN; + } + } + } + setJsonResponse(exc, json); + return CONTINUE; + } + + private static void setJsonResponse(Exchange exc, ObjectNode json) { + try { + exc.getRequest().setBodyContent(om.writeValueAsBytes(json)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private static ObjectNode getJson(Exchange exc, Flow flow) { + try { + if (om.readTree(exc.getMessage(flow).getBodyAsStreamDecoded()) instanceof ObjectNode on) { + return on; + } + throw new RuntimeException("Expected JSON Object"); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Outcome handleResponse(Exchange exc) { + var response = exc.getResponse(); + if (!response.isJSON()) { + log.debug("Response is not JSON"); + return CONTINUE; + } + + var json = getJson(exc, RESPONSE); + + // Error from AI API + if (!json.get("error").isNull()) { + return CONTINUE; + } + store.store(exc.getProperty(MEMBRANE_AI_USERTOKEN, String.class), getUsage(json)); + return CONTINUE; + } + + private Usage getUsage(ObjectNode json) { + var usage = json.get("usage"); + if (usage == null) { + return new Usage(0, 0, 0); + } + return new Usage(usage.get("input_tokens").asInt(), usage.get("output_tokens").asInt(), usage.get("total_tokens").asInt()); + } + + + + public String getApiKey() { + return apiKey; + } + + @MCAttribute + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public AiApiStore getAiStore() { + return null; + } + + @MCChildElement(allowForeign = true) + public void setAiStore(AiApiStore store) { + this.store = store; + } + + @Override + public String getDisplayName() { + return "OpenAI API"; + } + + public int getMaxOutputTokens() { + return maxOutputTokens; + } + + @MCAttribute + public void setMaxOutputTokens(int maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + } + + public int getMaxInputTokens() { + return maxInputTokens; + } + + @MCAttribute + public void setMaxInputTokens(int maxInputTokens) { + this.maxInputTokens = maxInputTokens; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index 6d13716360..ccaa19387a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -8,7 +8,7 @@ import java.util.List; import java.util.Optional; -@MCElement(name="simpleStore",component = false, noEnvelope = false, id="simple-ai-api-store") +@MCElement(name="simpleStore",component = false, id="simple-ai-api-store") public class SimpleAiApiStore implements AiApiStore { private static final Logger log = LoggerFactory.getLogger(SimpleAiApiStore.class); diff --git a/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java b/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java index 72beab6dbb..b3950b727a 100644 --- a/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java +++ b/core/src/main/java/com/predic8/membrane/core/security/AbstractSecurityScheme.java @@ -59,8 +59,4 @@ public Set getScopes() { return scopes; } - @Override - public String getPrincipal() { - return null; - } } From 77def744c659ce37a9c0a90ae73ad232dc8f37e7 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 7 May 2026 15:58:49 +0200 Subject: [PATCH 03/43] feat: improve concurrency for API rate limits and enhance error handling - Introduced thread-safe `AtomicLong` for token management. - Synchronized reset logic in `AiApiLimit`. - Improved error handling and null checks in OpenAI API interactions. - Default-initialized user list in `SimpleAiApiStore`. - Fixed getter for `AiApiStore` in interceptor. --- .../interceptor/ai/OpenAIAPIInterceptor.java | 25 +++++++++------- .../core/interceptor/ai/OpenAiApiUtil.java | 4 +-- .../core/interceptor/ai/store/AiApiLimit.java | 29 ++++++++++++------- .../ai/store/JDBCAiApiUsageStore.java | 2 +- .../ai/store/SimpleAiApiStore.java | 3 +- 5 files changed, 38 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java index d609f51260..073822a536 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java @@ -51,13 +51,14 @@ public Outcome handleRequest(Exchange exc) { var header = exc.getRequest().getHeader(); if (store != null) { - var user = store.getUser(extractBearerToken(header)); - log.debug("User: {}", user); - if (user.isEmpty()) { + var opt = store.getUser(extractBearerToken(header)); + if (opt.isEmpty()) { exc.setResponse(authenticationFailed()); return RETURN; } - var remaining = store.checkLimit(user.get()); + var user = opt.get(); + log.debug("User: {}", user); + var remaining = store.checkLimit(user); if (remaining <= 0) { exc.setResponse(tokenLimitExceeded()); return RETURN; @@ -117,20 +118,22 @@ public Outcome handleResponse(Exchange exc) { var json = getJson(exc, RESPONSE); - // Error from AI API - if (!json.get("error").isNull()) { + // Pass error from AI API to client + if (json.get("error") != null && !json.get("error").isNull()) { return CONTINUE; } - store.store(exc.getProperty(MEMBRANE_AI_USERTOKEN, String.class), getUsage(json)); + if (store != null) { + store.store(exc.getProperty(MEMBRANE_AI_USERTOKEN, String.class), getUsage(json)); + } return CONTINUE; } private Usage getUsage(ObjectNode json) { - var usage = json.get("usage"); - if (usage == null) { + var usage = json.path("usage"); + if (usage.isNull()) { return new Usage(0, 0, 0); } - return new Usage(usage.get("input_tokens").asInt(), usage.get("output_tokens").asInt(), usage.get("total_tokens").asInt()); + return new Usage(usage.path("input_tokens").asInt(), usage.path("output_tokens").asInt(), usage.path("total_tokens").asInt()); } @@ -145,7 +148,7 @@ public void setApiKey(String apiKey) { } public AiApiStore getAiStore() { - return null; + return store; } @MCChildElement(allowForeign = true) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java index 7a9c661516..5d63de213f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java @@ -23,7 +23,7 @@ public static Response authenticationFailed() { } public static Response contextLengthExceeded(int maxTokens, int estimatedTokens) { - return badRequest().json(createJson(new ErrorBody( + return badRequest().json(createJson(new ErrorEnvelope(new ErrorBody( """ This model's maximum context length is %d tokens. Your request contains approximately %d tokens. @@ -31,7 +31,7 @@ public static Response contextLengthExceeded(int maxTokens, int estimatedTokens) "invalid_request_error", "input", "context_length_exceeded" - ))).build(); + )))).build(); } public static Response tokenLimitExceeded() { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index 925c38ca1a..4957d73a34 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -5,7 +5,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.concurrent.GuardedBy; import java.time.Instant; +import java.util.concurrent.atomic.AtomicLong; import static java.time.Instant.now; @@ -16,25 +18,31 @@ public class AiApiLimit { private int maxTokens; private int period; + + private final Object lock = new Object(); + + @GuardedBy("lock") private Instant nextReset; - private long tokens; - public AiApiLimit() { - nextReset = now().plusSeconds(period); - } + private AtomicLong tokens = new AtomicLong(0); public long checkLimit() { - if (now().isAfter(nextReset)) { - tokens = 0; - nextReset = now().plusSeconds(period); - log.debug("Resetting AI API usage limit."); + Instant now = now(); + + if (now.isAfter(nextReset)) { + synchronized (lock) { + tokens.set(0); + nextReset = now.plusSeconds(period); + log.debug("Resetting AI API usage limit."); + } } - return maxTokens - tokens; + + return maxTokens - tokens.get(); } public void addTokens(long tokens) { log.debug("Adding {} tokens to AI API usage limit.", tokens); - this.tokens += tokens; + this.tokens.addAndGet(tokens); } public int getMaxTokens() { @@ -53,5 +61,6 @@ public int getPeriod() { @MCAttribute public void setPeriod(int period) { this.period = period; + nextReset = now().plusSeconds(period); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java index 541ed27d10..d702f0f15c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java @@ -12,7 +12,7 @@ public class JDBCAiApiUsageStore extends AbstractJdbcSupport implements AiApiSto private static final String CREATE_TABLE_SQL = """ CREATE TABLE IF NOT EXISTS ai_api_usage ( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, // @TODO GENERATED ALWAYS AS IDENTITY is PostgreSQL specific username VARCHAR(255) NOT NULL, input_tokens INT NOT NULL, output_tokens INT NOT NULL, diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index ccaa19387a..21e59e8945 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -5,6 +5,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -13,7 +14,7 @@ public class SimpleAiApiStore implements AiApiStore { private static final Logger log = LoggerFactory.getLogger(SimpleAiApiStore.class); - private List users; + private List users = new ArrayList<>(); private AiApiLimit limit = new AiApiLimit(); @Override From 8fa0d9571cbaeb17e0b38455f319e8970ee0f1a5 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 7 May 2026 15:59:18 +0200 Subject: [PATCH 04/43] refactor: make `tokens` in `AiApiLimit` final to improve immutability --- .../predic8/membrane/core/interceptor/ai/store/AiApiLimit.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index 4957d73a34..6450bf2dd7 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -24,7 +24,7 @@ public class AiApiLimit { @GuardedBy("lock") private Instant nextReset; - private AtomicLong tokens = new AtomicLong(0); + private final AtomicLong tokens = new AtomicLong(0); public long checkLimit() { Instant now = now(); From d44e839f062c6930fb38b01dec6afe6ce5c6c823 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 7 May 2026 18:08:38 +0200 Subject: [PATCH 05/43] feat: modularize AI providers and enhance OpenAI interceptor - Removed `AiUtil` and replaced with modular `AiProvider` interface. - Added provider implementations: `Claude`, `OpenAI`. - Updated `OpenAIAPIInterceptor` to use configurable providers and enforce model restrictions. - Introduced `NoAiApiLimit` for simplified limit management. - Enhanced error handling with model validation in `OpenAiApiUtil`. --- .../membrane/core/interceptor/ai/AiUtil.java | 55 ------------------- .../interceptor/ai/OpenAIAPIInterceptor.java | 53 +++++++++++++----- .../core/interceptor/ai/OpenAiApiUtil.java | 16 +++++- .../interceptor/ai/provider/AiProvider.java | 13 +++++ .../core/interceptor/ai/provider/Claude.java | 49 +++++++++++++++++ .../core/interceptor/ai/provider/OpenAI.java | 44 +++++++++++++++ .../core/interceptor/ai/store/AiApiLimit.java | 2 +- .../core/interceptor/ai/store/AiApiStore.java | 4 ++ .../interceptor/ai/store/NoAiApiLimit.java | 11 ++++ .../ai/store/SimpleAiApiStore.java | 2 +- 10 files changed, 178 insertions(+), 71 deletions(-) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java deleted file mode 100644 index 5c134d0ae9..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiUtil.java +++ /dev/null @@ -1,55 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai; - -import com.predic8.membrane.core.http.Header; - -public class AiUtil { - - public static final String BEARER_PREFIX = "Bearer"; - - private AiUtil() {} - - /** - * Estimates the number of tokens in a given text. - * The calculation assumes an average token length of 4 characters. - *

- * Content Approximation - * English prose chars / 4 - * German/French chars / 3.5 - * JSON/XML/code chars / 2.5–3 - * Chinese/Japanese very different - *

- * For API gateways, quotas, billing alerts, or rate limiting, approximate counting is often sufficient. - *

- * @param text the input string whose tokens are to be estimated - * @return the estimated number of tokens, rounded up to the nearest integer - */ - public static int estimateTokens(String text) { - return (int) Math.ceil(text.length() / 4.0); - } - - /** - * Extracts the Bearer token from the Authorization header. - * If the Authorization header is null or does not contain - * a Bearer token, this method returns null. - * - * @param header the Header object from which the Authorization - * header is to be extracted - * @return the Bearer token as a String if present; otherwise null - */ - public static String extractBearerToken(Header header) { - var ah = header.getAuthorization(); - if (ah == null) { - return null; - } - - int index = ah.indexOf(BEARER_PREFIX); - if (index < 0) { - return null; - } - - var token = ah.substring(index + BEARER_PREFIX.length()).trim(); - - return token.isEmpty() ? null : token; - } - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java index 073822a536..e13d84bb3e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java @@ -9,23 +9,22 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.AbstractInterceptor; import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.ai.provider.AiProvider; import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; import com.predic8.membrane.core.interceptor.ai.store.Usage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.List; -import static com.predic8.membrane.core.http.Header.AUTHORIZATION; import static com.predic8.membrane.core.interceptor.Interceptor.Flow.REQUEST; import static com.predic8.membrane.core.interceptor.Interceptor.Flow.RESPONSE; import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; -import static com.predic8.membrane.core.interceptor.ai.AiUtil.estimateTokens; -import static com.predic8.membrane.core.interceptor.ai.AiUtil.extractBearerToken; import static com.predic8.membrane.core.interceptor.ai.OpenAiApiUtil.*; -@MCElement(name = "openAI") +@MCElement(name = "aiGateway") public class OpenAIAPIInterceptor extends AbstractInterceptor { private static final Logger log = LoggerFactory.getLogger(OpenAIAPIInterceptor.class); @@ -35,9 +34,13 @@ public class OpenAIAPIInterceptor extends AbstractInterceptor { private static final ObjectMapper om = new ObjectMapper(); + private AiProvider provider; + private String apiKey; private int maxOutputTokens; private int maxInputTokens; + private List models; + private AiApiStore store; @Override @@ -51,7 +54,7 @@ public Outcome handleRequest(Exchange exc) { var header = exc.getRequest().getHeader(); if (store != null) { - var opt = store.getUser(extractBearerToken(header)); + var opt = store.getUser(provider.getApiKey(header)); if (opt.isEmpty()) { exc.setResponse(authenticationFailed()); return RETURN; @@ -66,8 +69,7 @@ public Outcome handleRequest(Exchange exc) { exc.setProperty(MEMBRANE_AI_USERTOKEN, user); } - header.removeFields(AUTHORIZATION); - header.add(AUTHORIZATION, "Bearer " + apiKey); + provider.setApiKey(header,apiKey); var json = getJson(exc, REQUEST); @@ -78,18 +80,27 @@ public Outcome handleRequest(Exchange exc) { if (maxInputTokens != 0) { var input = json.get("input"); if (input != null) { - var estimated = estimateTokens(input.asText()); + var estimated = provider.estimateInputTokens(json); if (estimated > maxInputTokens) { exc.setResponse(contextLengthExceeded(maxInputTokens, estimated)); return RETURN; } } } - setJsonResponse(exc, json); + + if (models != null) { + var model = json.path("model").asText(); + if (!models.contains(model)) { + exc.setResponse(modelNotAllowed(model, models)); + return RETURN; + } + } + + setJsonRequest(exc, json); return CONTINUE; } - private static void setJsonResponse(Exchange exc, ObjectNode json) { + private static void setJsonRequest(Exchange exc, ObjectNode json) { try { exc.getRequest().setBodyContent(om.writeValueAsBytes(json)); } catch (JsonProcessingException e) { @@ -136,8 +147,6 @@ private Usage getUsage(ObjectNode json) { return new Usage(usage.path("input_tokens").asInt(), usage.path("output_tokens").asInt(), usage.path("total_tokens").asInt()); } - - public String getApiKey() { return apiKey; } @@ -151,7 +160,7 @@ public AiApiStore getAiStore() { return store; } - @MCChildElement(allowForeign = true) + @MCChildElement(allowForeign = true, order = 10) public void setAiStore(AiApiStore store) { this.store = store; } @@ -178,4 +187,22 @@ public int getMaxInputTokens() { public void setMaxInputTokens(int maxInputTokens) { this.maxInputTokens = maxInputTokens; } + + public List getModels() { + return models; + } + + @MCAttribute + public void setModels(List models) { + this.models = models; + } + + public AiProvider getProvider() { + return provider; + } + + @MCChildElement(allowForeign = true, order = 0) + public void setProvider(AiProvider provider) { + this.provider = provider; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java index 5d63de213f..26df897f39 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java @@ -3,6 +3,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.predic8.membrane.core.http.Response; +import java.util.Collection; + import static com.predic8.membrane.core.http.Header.WWW_AUTHENTICATE; import static com.predic8.membrane.core.http.Response.badRequest; import static com.predic8.membrane.core.http.Response.unauthorized; @@ -11,6 +13,18 @@ public class OpenAiApiUtil { private static final ObjectMapper om = new ObjectMapper(); + public static Response modelNotAllowed(String model, Collection allowedModels) { + return badRequest().json(createJson(new ErrorEnvelope( + new ErrorBody( + "Model '%s' is not allowed. Allowed models: %s." + .formatted(model, String.join(", ", allowedModels)), + "invalid_request_error", + null, + "model_not_allowed" + ) + ))).build(); + } + public static Response authenticationFailed() { return unauthorized().header(WWW_AUTHENTICATE, "Bearer").json(createJson(new ErrorEnvelope( new ErrorBody( @@ -22,7 +36,7 @@ public static Response authenticationFailed() { ))).build(); } - public static Response contextLengthExceeded(int maxTokens, int estimatedTokens) { + public static Response contextLengthExceeded(long maxTokens, long estimatedTokens) { return badRequest().json(createJson(new ErrorEnvelope(new ErrorBody( """ This model's maximum context length is %d tokens. diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java new file mode 100644 index 0000000000..6a8b197cc2 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java @@ -0,0 +1,13 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.fasterxml.jackson.databind.JsonNode; +import com.predic8.membrane.core.http.Header; + +public interface AiProvider { + + void setApiKey(Header header, String apiKey); + + long estimateInputTokens(JsonNode json); + + String getApiKey(Header header); +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java new file mode 100644 index 0000000000..aaf6891ce1 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java @@ -0,0 +1,49 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.fasterxml.jackson.databind.JsonNode; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.http.Header; + +@MCElement( name="claude") +public class Claude implements AiProvider{ + + public static final String X_API_KEY = "x-api-key"; + + @Override + public void setApiKey(Header header, String apiKey) { + header.removeFields(X_API_KEY); + header.add(X_API_KEY, apiKey); + } + + @Override + public long estimateInputTokens(JsonNode request) { + int tokens = 0; + + // System prompt + tokens += request.path("system").asText().length() / 4; + + // Messages + for (JsonNode message : request.path("messages")) { + JsonNode content = message.path("content"); + if (content.isTextual()) { + tokens += content.asText().length() / 4; + } else if (content.isArray()) { + for (JsonNode block : content) { + String type = block.path("type").asText(); + if (type.equals("text")) { + tokens += block.path("text").asText().length() / 4; + } else if (type.equals("image")) { + tokens += 1000; + } + } + } + } + + return tokens; + } + + @Override + public String getApiKey(Header header) { + return header.getFirstValue(X_API_KEY); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java new file mode 100644 index 0000000000..43cd51dbf5 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java @@ -0,0 +1,44 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.fasterxml.jackson.databind.JsonNode; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.http.Header; + +import static com.predic8.membrane.core.http.Header.AUTHORIZATION; + +@MCElement( name="openai") +public class OpenAI implements AiProvider{ + + public static final String BEARER_PREFIX = "Bearer"; + + @Override + public void setApiKey(Header header, String apiKey) { + header.removeFields(AUTHORIZATION); + header.add(AUTHORIZATION, "Bearer " + apiKey); + } + + @Override + public long estimateInputTokens(JsonNode json) { + var input = json.path("input").asText(); + if (input == null) + return 0; + return (long) Math.ceil(input.length() / 4.0); + } + + @Override + public String getApiKey(Header header) { + var ah = header.getAuthorization(); + if (ah == null) { + return null; + } + + int index = ah.indexOf(BEARER_PREFIX); + if (index < 0) { + return null; + } + + var token = ah.substring(index + BEARER_PREFIX.length()).trim(); + + return token.isEmpty() ? null : token; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index 6450bf2dd7..9ef5982683 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -29,7 +29,7 @@ public class AiApiLimit { public long checkLimit() { Instant now = now(); - if (now.isAfter(nextReset)) { + if (nextReset == null || now.isAfter(nextReset)) { synchronized (lock) { tokens.set(0); nextReset = now.plusSeconds(period); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java index 41eafc6de5..36633e3a15 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java @@ -4,6 +4,10 @@ import java.util.Optional; +/** + * @TODO + * - Store .status, .error, .model, .stop_reason + */ public interface AiApiStore { default void init(Router router) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java new file mode 100644 index 0000000000..498abc57f6 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java @@ -0,0 +1,11 @@ +package com.predic8.membrane.core.interceptor.ai.store; + +public class NoAiApiLimit extends AiApiLimit{ + + @Override + public long checkLimit() { + return 1000; + } + + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index 21e59e8945..4045ca8c81 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -15,7 +15,7 @@ public class SimpleAiApiStore implements AiApiStore { private static final Logger log = LoggerFactory.getLogger(SimpleAiApiStore.class); private List users = new ArrayList<>(); - private AiApiLimit limit = new AiApiLimit(); + private AiApiLimit limit = new NoAiApiLimit(); @Override public void store(String user, Usage usage) { From 863e5cc273a2a93b8b9a40900ff0ba6c8b4d3717 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 8 May 2026 10:48:09 +0200 Subject: [PATCH 06/43] feat: implement modular AI provider framework with request/response abstraction - Added `AiApiRequest` and `AiApiResponse` abstractions for request/response handling. - Introduced `AbstractAiApiRequest` and `AbstractAiApiResponse` as base classes. - Implemented providers: `Google`, `OpenAI`, and `Claude` with concrete request/response handling. - Updated `AiProvider` to handle request/response creation. - Refactored `OpenAIAPIInterceptor` to leverage request/response abstraction and enforce contract restrictions. - Enhanced `JsonUtil` with helper methods for JSON body parsing and updates. - Updated `AiApiStore` and related classes for improved usage tracking and user abstraction. --- .../interceptor/ai/AbstractAiApiRequest.java | 59 +++++++++ .../interceptor/ai/AbstractAiApiResponse.java | 49 ++++++++ .../core/interceptor/ai/AiApiRequest.java | 19 +++ .../core/interceptor/ai/AiApiResponse.java | 10 ++ .../interceptor/ai/OpenAIAPIInterceptor.java | 87 ++++--------- .../interceptor/ai/provider/AiProvider.java | 11 +- .../core/interceptor/ai/provider/Claude.java | 45 ++----- .../ai/provider/ClaudeAiRequest.java | 48 +++++++ .../ai/provider/ClaudeAiResponse.java | 12 ++ .../core/interceptor/ai/provider/Google.java | 20 +++ .../ai/provider/GoogleAiRequest.java | 119 ++++++++++++++++++ .../ai/provider/GoogleAiResponse.java | 27 ++++ .../core/interceptor/ai/provider/OpenAI.java | 38 ++---- .../ai/provider/OpenAiAiRequest.java | 18 +++ .../ai/provider/OpenAiAiResponse.java | 12 ++ .../core/interceptor/ai/store/AiApiStore.java | 2 +- .../core/interceptor/ai/store/AiApiUser.java | 5 + .../ai/store/JDBCAiApiUsageStore.java | 4 +- .../ai/store/SimpleAiApiStore.java | 7 +- .../membrane/core/util/json/JsonUtil.java | 38 +++++- 20 files changed, 482 insertions(+), 148 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java new file mode 100644 index 0000000000..66ae69bc98 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java @@ -0,0 +1,59 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.util.json.JsonUtil; + +import static com.predic8.membrane.core.http.Header.AUTHORIZATION; + +public abstract class AbstractAiApiRequest implements AiApiRequest { + + public static final String BEARER_PREFIX = "Bearer"; + private static final String MAX_OUTPUT_TOKENS = "max_output_tokens"; + + protected final Exchange exchange; + protected ObjectNode json; + + public AbstractAiApiRequest(Exchange exchange) { + this.exchange = exchange; + json = JsonUtil.getJsonObject(exchange.getRequest()); + } + + @Override + public void setApiKey(String apiKey) { + exchange.getRequest().getHeader().removeFields(AUTHORIZATION); + exchange.getRequest().getHeader().add(AUTHORIZATION, "Bearer " + apiKey); + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + json.put(MAX_OUTPUT_TOKENS, maxOutputTokens); + } + + @Override + public String getApiKey() { + var ah = exchange.getRequest().getHeader().getAuthorization(); + if (ah == null) { + return null; + } + + int index = ah.indexOf(BEARER_PREFIX); + if (index < 0) { + return null; + } + + var token = ah.substring(index + BEARER_PREFIX.length()).trim(); + + return token.isEmpty() ? null : token; + } + + @Override + public ObjectNode getJson() { + return json; + } + + @Override + public String getModel() { + return json.path("model").asText(); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java new file mode 100644 index 0000000000..38ae473cc3 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java @@ -0,0 +1,49 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.json.JsonUtil; + +public abstract class AbstractAiApiResponse implements AiApiResponse { + + protected final Exchange exchange; + protected ObjectNode json; + + public AbstractAiApiResponse(Exchange exchange) { + this.exchange = exchange; + json = JsonUtil.getJsonObject(exchange.getResponse()); + } + + @Override + public boolean isError() { + return json.get("error") != null && !json.get("error").isNull(); + } + + public Usage getUsage() { + + var usage = json.path("usage"); + + int inputTokens = getInputTokens(usage); + int outputTokens = getOutputTokens(usage); + int totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); + + return new Usage( + inputTokens, + outputTokens, + totalTokens + ); + } + + private static int getOutputTokens(JsonNode usage) { + return usage.path("output_tokens").asInt( + usage.path("completion_tokens").asInt(0) + ); + } + + private static int getInputTokens(JsonNode usage) { + return usage.path("input_tokens").asInt( + usage.path("prompt_tokens").asInt(0)); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiRequest.java new file mode 100644 index 0000000000..cec5f30079 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiRequest.java @@ -0,0 +1,19 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.node.ObjectNode; + +public interface AiApiRequest { + + String getModel(); + + String getApiKey(); + + void setApiKey(String apiKey); + + void setMaxOutputTokens(int maxOutputTokens); + + long estimateInputTokens(); + + ObjectNode getJson(); + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiResponse.java new file mode 100644 index 0000000000..b01e135626 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiResponse.java @@ -0,0 +1,10 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.predic8.membrane.core.interceptor.ai.store.Usage; + +public interface AiApiResponse { + + boolean isError(); + + Usage getUsage(); +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java index e13d84bb3e..3d7fd0bb2a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java @@ -1,8 +1,5 @@ package com.predic8.membrane.core.interceptor.ai; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.annot.MCAttribute; import com.predic8.membrane.annot.MCChildElement; import com.predic8.membrane.annot.MCElement; @@ -11,28 +8,23 @@ import com.predic8.membrane.core.interceptor.Outcome; import com.predic8.membrane.core.interceptor.ai.provider.AiProvider; import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; -import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.interceptor.ai.store.AiApiUser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; import java.util.List; -import static com.predic8.membrane.core.interceptor.Interceptor.Flow.REQUEST; -import static com.predic8.membrane.core.interceptor.Interceptor.Flow.RESPONSE; import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; import static com.predic8.membrane.core.interceptor.ai.OpenAiApiUtil.*; +import static com.predic8.membrane.core.util.json.JsonUtil.setJsonBody; @MCElement(name = "aiGateway") public class OpenAIAPIInterceptor extends AbstractInterceptor { private static final Logger log = LoggerFactory.getLogger(OpenAIAPIInterceptor.class); - public static final String MEMBRANE_AI_USERTOKEN = "membrane.ai.usertoken"; - public static final String MAX_OUTPUT_TOKENS = "max_output_tokens"; - - private static final ObjectMapper om = new ObjectMapper(); + public static final String MEMBRANE_AI_USER = "membrane.ai.user"; private AiProvider provider; @@ -51,10 +43,11 @@ public void init() { @Override public Outcome handleRequest(Exchange exc) { - var header = exc.getRequest().getHeader(); + + var aiReq = provider.getAiApiRequest(exc); if (store != null) { - var opt = store.getUser(provider.getApiKey(header)); + var opt = store.getUser(aiReq.getApiKey()); if (opt.isEmpty()) { exc.setResponse(authenticationFailed()); return RETURN; @@ -66,87 +59,49 @@ public Outcome handleRequest(Exchange exc) { exc.setResponse(tokenLimitExceeded()); return RETURN; } - exc.setProperty(MEMBRANE_AI_USERTOKEN, user); + exc.setProperty(MEMBRANE_AI_USER, user); } - provider.setApiKey(header,apiKey); - - var json = getJson(exc, REQUEST); + aiReq.setApiKey(apiKey); if (maxOutputTokens != 0) { - json.put(MAX_OUTPUT_TOKENS, maxOutputTokens); + aiReq.setMaxOutputTokens(maxOutputTokens); } if (maxInputTokens != 0) { - var input = json.get("input"); - if (input != null) { - var estimated = provider.estimateInputTokens(json); - if (estimated > maxInputTokens) { - exc.setResponse(contextLengthExceeded(maxInputTokens, estimated)); - return RETURN; - } + var estimated = aiReq.estimateInputTokens(); + if (estimated > maxInputTokens) { + exc.setResponse(contextLengthExceeded(maxInputTokens, estimated)); + return RETURN; } } if (models != null) { - var model = json.path("model").asText(); + var model = aiReq.getModel(); if (!models.contains(model)) { exc.setResponse(modelNotAllowed(model, models)); return RETURN; } } - setJsonRequest(exc, json); + setJsonBody(exc.getRequest(), aiReq.getJson()); return CONTINUE; } - private static void setJsonRequest(Exchange exc, ObjectNode json) { - try { - exc.getRequest().setBodyContent(om.writeValueAsBytes(json)); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - private static ObjectNode getJson(Exchange exc, Flow flow) { - try { - if (om.readTree(exc.getMessage(flow).getBodyAsStreamDecoded()) instanceof ObjectNode on) { - return on; - } - throw new RuntimeException("Expected JSON Object"); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - @Override public Outcome handleResponse(Exchange exc) { - var response = exc.getResponse(); - if (!response.isJSON()) { - log.debug("Response is not JSON"); - return CONTINUE; - } - var json = getJson(exc, RESPONSE); + var aiRes = provider.getAiApiResponse(exc); + + if (aiRes.isError()) + return CONTINUE; // pass error from AI API to client - // Pass error from AI API to client - if (json.get("error") != null && !json.get("error").isNull()) { - return CONTINUE; - } if (store != null) { - store.store(exc.getProperty(MEMBRANE_AI_USERTOKEN, String.class), getUsage(json)); + store.store(exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class), aiRes.getUsage()); } return CONTINUE; } - private Usage getUsage(ObjectNode json) { - var usage = json.path("usage"); - if (usage.isNull()) { - return new Usage(0, 0, 0); - } - return new Usage(usage.path("input_tokens").asInt(), usage.path("output_tokens").asInt(), usage.path("total_tokens").asInt()); - } - public String getApiKey() { return apiKey; } @@ -201,7 +156,7 @@ public AiProvider getProvider() { return provider; } - @MCChildElement(allowForeign = true, order = 0) + @MCChildElement(allowForeign = true) public void setProvider(AiProvider provider) { this.provider = provider; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java index 6a8b197cc2..4b99f8f46d 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java @@ -1,13 +1,12 @@ package com.predic8.membrane.core.interceptor.ai.provider; -import com.fasterxml.jackson.databind.JsonNode; -import com.predic8.membrane.core.http.Header; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AiApiRequest; +import com.predic8.membrane.core.interceptor.ai.AiApiResponse; public interface AiProvider { - void setApiKey(Header header, String apiKey); + AiApiRequest getAiApiRequest(Exchange request); + AiApiResponse getAiApiResponse(Exchange request); - long estimateInputTokens(JsonNode json); - - String getApiKey(Header header); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java index aaf6891ce1..93b60ba26f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java @@ -1,49 +1,20 @@ package com.predic8.membrane.core.interceptor.ai.provider; -import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.http.Header; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AiApiRequest; +import com.predic8.membrane.core.interceptor.ai.AiApiResponse; @MCElement( name="claude") -public class Claude implements AiProvider{ - - public static final String X_API_KEY = "x-api-key"; - - @Override - public void setApiKey(Header header, String apiKey) { - header.removeFields(X_API_KEY); - header.add(X_API_KEY, apiKey); - } +public class Claude implements AiProvider { @Override - public long estimateInputTokens(JsonNode request) { - int tokens = 0; - - // System prompt - tokens += request.path("system").asText().length() / 4; - - // Messages - for (JsonNode message : request.path("messages")) { - JsonNode content = message.path("content"); - if (content.isTextual()) { - tokens += content.asText().length() / 4; - } else if (content.isArray()) { - for (JsonNode block : content) { - String type = block.path("type").asText(); - if (type.equals("text")) { - tokens += block.path("text").asText().length() / 4; - } else if (type.equals("image")) { - tokens += 1000; - } - } - } - } - - return tokens; + public AiApiRequest getAiApiRequest(Exchange exchange) { + return new ClaudeAiRequest(exchange); } @Override - public String getApiKey(Header header) { - return header.getFirstValue(X_API_KEY); + public AiApiResponse getAiApiResponse(Exchange request) { + return new ClaudeAiResponse(request); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java new file mode 100644 index 0000000000..087c2ec55e --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java @@ -0,0 +1,48 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractAiApiRequest; + +public class ClaudeAiRequest extends AbstractAiApiRequest { + + public static final String X_API_KEY = "x-api-key"; + + public ClaudeAiRequest(Exchange exchange) { + super(exchange); + } + + @Override + public long estimateInputTokens() { + // System prompt + int tokens = json.path("system").asText().length() / 4; + + // Messages + for (var message : json.path("messages")) { + var content = message.path("content"); + if (content.isTextual()) { + tokens += content.asText().length() / 4; + } else if (content.isArray()) { + for (var block : content) { + var type = block.path("type").asText(); + if (type.equals("text")) { + tokens += block.path("text").asText().length() / 4; + } else if (type.equals("image")) { + tokens += 1000; + } + } + } + } + return tokens; + } + + @Override + public String getApiKey() { + return exchange.getRequest().getHeader().getFirstValue(X_API_KEY); + } + + @Override + public void setApiKey(String apiKey) { + exchange.getRequest().getHeader().removeFields(X_API_KEY); + exchange.getRequest().getHeader().add(X_API_KEY, apiKey); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java new file mode 100644 index 0000000000..a500e5bfdf --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java @@ -0,0 +1,12 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractAiApiResponse; + +public class ClaudeAiResponse extends AbstractAiApiResponse { + + public ClaudeAiResponse(Exchange exchange) { + super(exchange); + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java new file mode 100644 index 0000000000..b87ed46c11 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java @@ -0,0 +1,20 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AiApiRequest; +import com.predic8.membrane.core.interceptor.ai.AiApiResponse; + +@MCElement( name="google",id = "google-ai-provider") +public class Google implements AiProvider { + + @Override + public AiApiRequest getAiApiRequest(Exchange exchange) { + return new GoogleAiRequest(exchange); + } + + @Override + public AiApiResponse getAiApiResponse(Exchange request) { + return new GoogleAiResponse(request); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java new file mode 100644 index 0000000000..a1af123af7 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java @@ -0,0 +1,119 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractAiApiRequest; + +public class GoogleAiRequest extends AbstractAiApiRequest { + + public static final String X_GOOG_API_KEY = "x-goog-api-key"; + + public GoogleAiRequest(Exchange exchange) { + super(exchange); + } + + @Override + public String getModel() { + + var uri = exchange.getRequest().getUri(); + + if (uri == null) { + return null; + } + + // Example: + // /v1beta/models/gemini-2.5-pro:generateContent + int modelsIndex = uri.indexOf("/models/"); + if (modelsIndex < 0) { + return null; + } + + var modelPart = uri.substring(modelsIndex + "/models/".length()); + + int colonIndex = modelPart.indexOf(':'); + if (colonIndex >= 0) { + return modelPart.substring(0, colonIndex); + } + + return modelPart; + } + + @Override + public String getApiKey() { + return exchange.getRequest().getHeader().getFirstValue(X_GOOG_API_KEY); + } + + @Override + public void setApiKey(String apiKey) { + exchange.getRequest().getHeader().removeFields(X_GOOG_API_KEY); + exchange.getRequest().getHeader().add(X_GOOG_API_KEY, apiKey); + } + + public long estimateInputTokens() { + if (json == null || json.isNull()) { + return 0; + } + + long chars = countText(json.path("systemInstruction")); + + var contents = json.path("contents"); + if (contents.isArray()) { + for (JsonNode content : contents) { + chars += countText(content.path("parts")); + } + } + + // Safety margin for JSON structure, roles, metadata, etc. + return Math.max(1, Math.round(chars / 4.0 * 1.15)); + } + + private long countText(JsonNode node) { + if (node == null || node.isMissingNode() || node.isNull()) { + return 0; + } + + if (node.isTextual()) { + return node.asText().length(); + } + + if (node.isObject()) { + long chars = 0; + + JsonNode text = node.get("text"); + if (text != null && text.isTextual()) { + chars += text.asText().length(); + } + + JsonNode parts = node.get("parts"); + if (parts != null) { + chars += countText(parts); + } + + return chars; + } + + if (node.isArray()) { + long chars = 0; + for (JsonNode child : node) { + chars += countText(child); + } + return chars; + } + + return 0; + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + getGenerationConfig().put("maxOutputTokens", maxOutputTokens); + } + + private ObjectNode getGenerationConfig() { + var gc = json.get("generationConfig"); + if (gc instanceof ObjectNode objectNode) { + return objectNode; + } + return json.putObject("generationConfig"); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiResponse.java new file mode 100644 index 0000000000..66add9a43e --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiResponse.java @@ -0,0 +1,27 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractAiApiResponse; +import com.predic8.membrane.core.interceptor.ai.store.Usage; + +public class GoogleAiResponse extends AbstractAiApiResponse { + + public GoogleAiResponse(Exchange exchange) { + super(exchange); + } + + @Override + public Usage getUsage() { + var usage = json.path("usageMetadata"); + + int inputTokens = usage.path("promptTokenCount").asInt(0); + int outputTokens = usage.path("candidatesTokenCount").asInt(0); + int totalTokens = usage.path("totalTokenCount").asInt(inputTokens + outputTokens); + + return new Usage( + inputTokens, + outputTokens, + totalTokens + ); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java index 43cd51dbf5..91b0ee8889 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java @@ -1,44 +1,20 @@ package com.predic8.membrane.core.interceptor.ai.provider; -import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.http.Header; - -import static com.predic8.membrane.core.http.Header.AUTHORIZATION; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AiApiRequest; +import com.predic8.membrane.core.interceptor.ai.AiApiResponse; @MCElement( name="openai") public class OpenAI implements AiProvider{ - public static final String BEARER_PREFIX = "Bearer"; - - @Override - public void setApiKey(Header header, String apiKey) { - header.removeFields(AUTHORIZATION); - header.add(AUTHORIZATION, "Bearer " + apiKey); - } - @Override - public long estimateInputTokens(JsonNode json) { - var input = json.path("input").asText(); - if (input == null) - return 0; - return (long) Math.ceil(input.length() / 4.0); + public AiApiRequest getAiApiRequest(Exchange exchange) { + return new OpenAiAiRequest(exchange); } @Override - public String getApiKey(Header header) { - var ah = header.getAuthorization(); - if (ah == null) { - return null; - } - - int index = ah.indexOf(BEARER_PREFIX); - if (index < 0) { - return null; - } - - var token = ah.substring(index + BEARER_PREFIX.length()).trim(); - - return token.isEmpty() ? null : token; + public AiApiResponse getAiApiResponse(Exchange request) { + return new OpenAiAiResponse(request); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java new file mode 100644 index 0000000000..43403e90dc --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java @@ -0,0 +1,18 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractAiApiRequest; + +import static java.lang.Math.ceil; + +public class OpenAiAiRequest extends AbstractAiApiRequest { + + public OpenAiAiRequest(Exchange exchange) { + super(exchange); + } + + @Override + public long estimateInputTokens() { + return (long) ceil(json.path("input").asText("").length() / 4.0); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java new file mode 100644 index 0000000000..a20575cf68 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java @@ -0,0 +1,12 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractAiApiResponse; + +public class OpenAiAiResponse extends AbstractAiApiResponse { + + public OpenAiAiResponse(Exchange exchange) { + super(exchange); + } + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java index 36633e3a15..3f205255ff 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java @@ -13,7 +13,7 @@ public interface AiApiStore { default void init(Router router) { } - void store(String user, Usage usage); + void store(AiApiUser user, Usage usage); Optional getUser(String token); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java index 61125c1d05..a867bfddcc 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -26,4 +26,9 @@ public String getToken() { public void setToken(String token) { this.token = token; } + + @Override + public String toString() { + return "user(name: %s)".formatted(name); + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java index d702f0f15c..05b6f729bb 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java @@ -37,9 +37,9 @@ public void init(Router router) { } @Override - public void store(String user, com.predic8.membrane.core.interceptor.ai.store.Usage usage) { + public void store(AiApiUser user, com.predic8.membrane.core.interceptor.ai.store.Usage usage) { try (var connection = getConnection(); var ps = connection.prepareStatement(INSERT_SQL)) { - ps.setString(1, user); + ps.setString(1, user.getName()); ps.setInt(2, usage.inputTokens()); ps.setInt(3, usage.outputTokens()); ps.setInt(4, usage.totalTokens()); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index 4045ca8c81..28f0f6db0e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -18,8 +18,8 @@ public class SimpleAiApiStore implements AiApiStore { private AiApiLimit limit = new NoAiApiLimit(); @Override - public void store(String user, Usage usage) { - log.info("User: {} Usage: {}", user, usage); + public void store(AiApiUser user, Usage usage) { + log.info("User: {} Usage: {}", user.getName(), usage); limit.addTokens(usage.totalTokens()); } @@ -30,6 +30,9 @@ public Optional getUser(String token) { @Override public long checkLimit(AiApiUser user) { + if (user == null) + return 0; // anonymous user gets no tokens + return limit.checkLimit(); } diff --git a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java index 5fb73f8092..e93a77d3c4 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java @@ -14,13 +14,23 @@ package com.predic8.membrane.core.util.json; -import com.fasterxml.jackson.databind.*; -import com.fasterxml.jackson.databind.node.*; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.http.Message; -import java.math.*; +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; + +import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON; public class JsonUtil { + private static final ObjectMapper om = new ObjectMapper(); + private static final JsonNodeFactory FACTORY = JsonNodeFactory.instance; /** @@ -75,4 +85,26 @@ public static JsonNode scalarAsJson(String value) { return FACTORY.textNode(value); } + + public static ObjectNode getJsonObject(Message msg) { + try { + if (om.readTree(msg.getBodyAsStreamDecoded()) instanceof ObjectNode on) { + return on; + } + throw new RuntimeException("Expected JSON Object"); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static void setJsonBody(Message msg, ObjectNode json) { + try { + if (!msg.isJSON()) { + msg.getHeader().setContentType(APPLICATION_JSON); + } + msg.setBodyContent(om.writeValueAsBytes(json)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } } From 2933dad2a3849ca6474695391ec11d4cf4089753 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 8 May 2026 11:27:23 +0200 Subject: [PATCH 07/43] feat: enhance token limit management and error handling across AI components - Updated `checkLimit` method to consider input and output tokens. - Improved token calculation logic in AI request providers. - Enhanced JSON parsing in `JsonUtil` with Optional for safer operations. - Added detailed error handling in `OpenAIAPIInterceptor` for invalid requests. - Refined token estimation logic with safety margins and JSON structure considerations. --- .../interceptor/ai/AbstractAiApiRequest.java | 2 +- .../interceptor/ai/AbstractAiApiResponse.java | 3 +- .../interceptor/ai/OpenAIAPIInterceptor.java | 21 ++++-- .../ai/provider/ClaudeAiRequest.java | 2 +- .../ai/provider/OpenAiAiRequest.java | 75 ++++++++++++++++++- .../core/interceptor/ai/store/AiApiLimit.java | 11 ++- .../core/interceptor/ai/store/AiApiStore.java | 7 +- .../ai/store/JDBCAiApiUsageStore.java | 2 +- .../interceptor/ai/store/NoAiApiLimit.java | 6 +- .../ai/store/SimpleAiApiStore.java | 4 +- .../membrane/core/util/json/JsonUtil.java | 19 +++-- 11 files changed, 126 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java index 66ae69bc98..6f6b1e9193 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java @@ -16,7 +16,7 @@ public abstract class AbstractAiApiRequest implements AiApiRequest { public AbstractAiApiRequest(Exchange exchange) { this.exchange = exchange; - json = JsonUtil.getJsonObject(exchange.getRequest()); + json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("No JSON object request.")); } @Override diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java index 38ae473cc3..7ba2ed49dd 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java @@ -1,6 +1,7 @@ package com.predic8.membrane.core.interceptor.ai; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.ai.store.Usage; @@ -13,7 +14,7 @@ public abstract class AbstractAiApiResponse implements AiApiResponse { public AbstractAiApiResponse(Exchange exchange) { this.exchange = exchange; - json = JsonUtil.getJsonObject(exchange.getResponse()); + json = JsonUtil.getJsonObject(exchange.getResponse()).orElse(JsonNodeFactory.instance.objectNode().put("error", "No JSON object response from model.")); } @Override diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java index 3d7fd0bb2a..77af8dc276 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java @@ -14,6 +14,7 @@ import java.util.List; +import static com.predic8.membrane.core.exceptions.ProblemDetails.user; import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; import static com.predic8.membrane.core.interceptor.ai.OpenAiApiUtil.*; @@ -44,7 +45,18 @@ public void init() { @Override public Outcome handleRequest(Exchange exc) { - var aiReq = provider.getAiApiRequest(exc); + AiApiRequest aiReq; + try { + aiReq = provider.getAiApiRequest(exc); + } catch (Exception e) { + user(router.getConfiguration().isProduction(),"AI Gateway") + .title("Invalid request") + .detail("Error parsing request: " + e.getMessage()) + .buildAndSetResponse(exc); + return RETURN; + } + + var inputTokens = aiReq.estimateInputTokens(); if (store != null) { var opt = store.getUser(aiReq.getApiKey()); @@ -54,7 +66,7 @@ public Outcome handleRequest(Exchange exc) { } var user = opt.get(); log.debug("User: {}", user); - var remaining = store.checkLimit(user); + var remaining = store.checkLimit(user,inputTokens,maxOutputTokens); if (remaining <= 0) { exc.setResponse(tokenLimitExceeded()); return RETURN; @@ -69,9 +81,8 @@ public Outcome handleRequest(Exchange exc) { } if (maxInputTokens != 0) { - var estimated = aiReq.estimateInputTokens(); - if (estimated > maxInputTokens) { - exc.setResponse(contextLengthExceeded(maxInputTokens, estimated)); + if (inputTokens > maxInputTokens) { + exc.setResponse(contextLengthExceeded(maxInputTokens, inputTokens)); return RETURN; } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java index 087c2ec55e..96a4d39a4f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java @@ -14,7 +14,7 @@ public ClaudeAiRequest(Exchange exchange) { @Override public long estimateInputTokens() { // System prompt - int tokens = json.path("system").asText().length() / 4; + long tokens = json.path("system").asText().length() / 4; // Messages for (var message : json.path("messages")) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java index 43403e90dc..d52d8d06c9 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java @@ -1,10 +1,9 @@ package com.predic8.membrane.core.interceptor.ai.provider; +import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.ai.AbstractAiApiRequest; -import static java.lang.Math.ceil; - public class OpenAiAiRequest extends AbstractAiApiRequest { public OpenAiAiRequest(Exchange exchange) { @@ -13,6 +12,76 @@ public OpenAiAiRequest(Exchange exchange) { @Override public long estimateInputTokens() { - return (long) ceil(json.path("input").asText("").length() / 4.0); + + long chars = countText(json.path("input")); + + chars += estimateChatCompletitions(); + + // system instructions + chars += countText(json.path("system")); + + // tools/functions contribute significantly + chars += countJsonSize(json.path("tools")); + chars += countJsonSize(json.path("functions")); + + // safety margin for JSON structure and tokenizer variance + return Math.max(1, Math.round(chars / 4.0 * 1.15)); + } + + private long estimateChatCompletitions() { + long chars = 0; + // Chat Completions API + var messages = json.path("messages"); + if (messages.isArray()) { + for (var message : messages) { + chars += countText(message.path("content")); + // roles also consume tokens + chars += message.path("role").asText("").length(); + } + } + return chars; + } + + private long countText(JsonNode node) { + if (node == null || node.isMissingNode() || node.isNull()) { + return 0; + } + + if (node.isTextual()) { + return node.asText().length(); + } + + if (node.isArray()) { + long chars = 0; + for (JsonNode child : node) { + chars += countText(child); + } + return chars; + } + + if (node.isObject()) { + + // OpenAI content blocks: + // { "type": "text", "text": "..." } + long chars = 0; + + var text = node.get("text"); + if (text != null && text.isTextual()) { + chars += text.asText().length(); + } + + chars += countText(node.get("content")); + + return chars; + } + + return 0; + } + + private long countJsonSize(JsonNode node) { + if (node == null || node.isMissingNode() || node.isNull()) { + return 0; + } + return node.toString().length(); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index 9ef5982683..bb0b135df8 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -26,7 +26,14 @@ public class AiApiLimit { private final AtomicLong tokens = new AtomicLong(0); - public long checkLimit() { + /** + * Checks if the user has enough tokens to make the request. + * If there aren't enough tokens for the request, 0 or a negative number is returned. + * + * @param tokensForNextRequest + * @return Estimated remaining tokens after this call. + */ + public long checkLimit(long tokensForNextRequest) { Instant now = now(); if (nextReset == null || now.isAfter(nextReset)) { @@ -37,7 +44,7 @@ public long checkLimit() { } } - return maxTokens - tokens.get(); + return maxTokens - tokens.get() - tokensForNextRequest; } public void addTokens(long tokens) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java index 3f205255ff..4f3a4c900c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java @@ -17,6 +17,11 @@ default void init(Router router) { Optional getUser(String token); - long checkLimit(AiApiUser user); + /** + * Checks if the user has enough tokens to make the request. + * @param user + * @return + */ + long checkLimit(AiApiUser user, long inputTokens, long outputTokens); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java index 05b6f729bb..acff4625c1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java @@ -56,7 +56,7 @@ public Optional getUser(String token) { } @Override - public long checkLimit(AiApiUser user) { + public long checkLimit(AiApiUser user, long inputTokens, long outputTokens) { return 0; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java index 498abc57f6..63cdaa6d40 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java @@ -3,9 +3,7 @@ public class NoAiApiLimit extends AiApiLimit{ @Override - public long checkLimit() { - return 1000; + public long checkLimit(long tokensForNextRequest) { + return 1000; // Returns a value greater than 0 to indicate that the request can be processed. } - - } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index 28f0f6db0e..b63c3bd6d6 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -29,11 +29,11 @@ public Optional getUser(String token) { } @Override - public long checkLimit(AiApiUser user) { + public long checkLimit(AiApiUser user, long inputTokens, long outputTokens) { if (user == null) return 0; // anonymous user gets no tokens - return limit.checkLimit(); + return limit.checkLimit(inputTokens + outputTokens); } @MCChildElement(allowForeign = true,order = 10) diff --git a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java index e93a77d3c4..91177a5e1e 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java @@ -20,15 +20,22 @@ import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.http.Message; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; +import java.util.Optional; import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON; +import static java.util.Optional.empty; public class JsonUtil { + private static final Logger log = LoggerFactory.getLogger(JsonUtil.class); + + private static final ObjectMapper om = new ObjectMapper(); private static final JsonNodeFactory FACTORY = JsonNodeFactory.instance; @@ -86,15 +93,17 @@ public static JsonNode scalarAsJson(String value) { return FACTORY.textNode(value); } - public static ObjectNode getJsonObject(Message msg) { + public static Optional getJsonObject(Message msg) { try { - if (om.readTree(msg.getBodyAsStreamDecoded()) instanceof ObjectNode on) { - return on; + JsonNode jsonNode = om.readTree(msg.getBodyAsStreamDecoded()); + if (jsonNode instanceof ObjectNode on) { + return Optional.of(on); } - throw new RuntimeException("Expected JSON Object"); + log.debug("Expected JSON Object but got: {}",jsonNode.getNodeType()); } catch (IOException e) { - throw new RuntimeException(e); + log.debug("Error reading JSON: {}", e.getMessage()); } + return empty(); } public static void setJsonBody(Message msg, ObjectNode json) { From 402651daef721281640348d1945f17e8176f710b Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 8 May 2026 13:28:00 +0200 Subject: [PATCH 08/43] feat: improve concurrency and logging for API rate limit management - Synchronized token management methods in `AiApiLimit` to ensure thread safety. - Adjusted log levels for `SimpleAiApiStore` to reduce verbosity. - Added PostgreSQL dependency to the distribution. - Updated logging configuration to set debug level for AI interceptors. --- .../core/interceptor/ai/store/AiApiLimit.java | 19 +++++++++++-------- .../ai/store/SimpleAiApiStore.java | 2 +- .../membrane/core/util/json/JsonUtil.java | 9 ++++----- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index bb0b135df8..c807931c76 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -34,10 +34,9 @@ public class AiApiLimit { * @return Estimated remaining tokens after this call. */ public long checkLimit(long tokensForNextRequest) { - Instant now = now(); - - if (nextReset == null || now.isAfter(nextReset)) { - synchronized (lock) { + synchronized (lock) { + Instant now = now(); + if (nextReset == null || now.isAfter(nextReset)) { tokens.set(0); nextReset = now.plusSeconds(period); log.debug("Resetting AI API usage limit."); @@ -48,8 +47,10 @@ public long checkLimit(long tokensForNextRequest) { } public void addTokens(long tokens) { - log.debug("Adding {} tokens to AI API usage limit.", tokens); - this.tokens.addAndGet(tokens); + synchronized (lock) { + log.debug("Adding {} tokens to AI API usage limit.", tokens); + this.tokens.addAndGet(tokens); + } } public int getMaxTokens() { @@ -67,7 +68,9 @@ public int getPeriod() { @MCAttribute public void setPeriod(int period) { - this.period = period; - nextReset = now().plusSeconds(period); + synchronized (lock) { + this.period = period; + nextReset = now().plusSeconds(period); + } } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index b63c3bd6d6..ff99faa98a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -19,7 +19,7 @@ public class SimpleAiApiStore implements AiApiStore { @Override public void store(AiApiUser user, Usage usage) { - log.info("User: {} Usage: {}", user.getName(), usage); + log.debug("Usage: {}", usage); limit.addTokens(usage.totalTokens()); } diff --git a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java index 91177a5e1e..21b78f7128 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java @@ -23,7 +23,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; import java.util.Optional; @@ -95,12 +94,12 @@ public static JsonNode scalarAsJson(String value) { public static Optional getJsonObject(Message msg) { try { - JsonNode jsonNode = om.readTree(msg.getBodyAsStreamDecoded()); - if (jsonNode instanceof ObjectNode on) { + var node = om.readTree(msg.getBodyAsStreamDecoded()); + if (node instanceof ObjectNode on) { return Optional.of(on); } - log.debug("Expected JSON Object but got: {}",jsonNode.getNodeType()); - } catch (IOException e) { + log.debug("Expected JSON Object but got: {}",node.getNodeType()); + } catch (Exception e) { log.debug("Error reading JSON: {}", e.getMessage()); } return empty(); From ece8d1dd579286643f0cb56522fea36dec60d84c Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 8 May 2026 13:28:30 +0200 Subject: [PATCH 09/43] docs: clarify parameter description in `checkLimit` method --- .../predic8/membrane/core/interceptor/ai/store/AiApiLimit.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index c807931c76..eb82874372 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -30,7 +30,7 @@ public class AiApiLimit { * Checks if the user has enough tokens to make the request. * If there aren't enough tokens for the request, 0 or a negative number is returned. * - * @param tokensForNextRequest + * @param tokensForNextRequest Estimation of the number of tokens that will be used for the next request. * @return Estimated remaining tokens after this call. */ public long checkLimit(long tokensForNextRequest) { From 5dc551349de4dbed47efc4f9656d10ac57174255 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 8 May 2026 16:58:50 +0200 Subject: [PATCH 10/43] feat: add SSE event parsing and improve token handling - Introduced `SSEUtil` for parsing Server-Sent Events (SSE) from chunks. - Enhanced `AbstractAiApiRequest` to handle JSON requests conditionally. - Deprecated and replaced `max_output_tokens` usage in specific providers. - Improved stream support in `OpenAiAiRequest` with response usage tracking. - Refactored token limit logic in `OpenAIAPIInterceptor` for better flow. --- .../interceptor/ai/AbstractAiApiRequest.java | 9 +--- .../interceptor/ai/OpenAIAPIInterceptor.java | 44 ++++++++++++++++--- .../core/interceptor/ai/OpenAiApiUtil.java | 10 +++++ .../ai/provider/ClaudeAiRequest.java | 4 ++ .../ai/provider/GoogleAiRequest.java | 2 +- .../ai/provider/OpenAiAiRequest.java | 18 ++++++++ .../predic8/membrane/core/util/SSEUtil.java | 31 +++++++++++++ .../membrane/core/util/json/JsonUtil.java | 22 +++++++++- 8 files changed, 126 insertions(+), 14 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java index 6f6b1e9193..5166dc7820 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java @@ -9,14 +9,14 @@ public abstract class AbstractAiApiRequest implements AiApiRequest { public static final String BEARER_PREFIX = "Bearer"; - private static final String MAX_OUTPUT_TOKENS = "max_output_tokens"; protected final Exchange exchange; protected ObjectNode json; public AbstractAiApiRequest(Exchange exchange) { this.exchange = exchange; - json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("No JSON object request.")); + if (exchange.getRequest().isJSON()) + json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("No JSON object request.")); } @Override @@ -25,11 +25,6 @@ public void setApiKey(String apiKey) { exchange.getRequest().getHeader().add(AUTHORIZATION, "Bearer " + apiKey); } - @Override - public void setMaxOutputTokens(int maxOutputTokens) { - json.put(MAX_OUTPUT_TOKENS, maxOutputTokens); - } - @Override public String getApiKey() { var ah = exchange.getRequest().getHeader().getAuthorization(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java index 77af8dc276..347f68d718 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java @@ -4,11 +4,15 @@ import com.predic8.membrane.annot.MCChildElement; import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.http.AbstractMessageObserver; +import com.predic8.membrane.core.http.Chunk; import com.predic8.membrane.core.interceptor.AbstractInterceptor; import com.predic8.membrane.core.interceptor.Outcome; import com.predic8.membrane.core.interceptor.ai.provider.AiProvider; import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; import com.predic8.membrane.core.interceptor.ai.store.AiApiUser; +import com.predic8.membrane.core.util.SSEUtil; +import com.predic8.membrane.core.util.json.JsonUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,7 +60,12 @@ public Outcome handleRequest(Exchange exc) { return RETURN; } - var inputTokens = aiReq.estimateInputTokens(); + if (!exc.getRequest().isPOSTRequest()) { + aiReq.setApiKey(apiKey); + return CONTINUE; + } + + long inputTokens = 0; if (store != null) { var opt = store.getUser(aiReq.getApiKey()); @@ -66,10 +75,13 @@ public Outcome handleRequest(Exchange exc) { } var user = opt.get(); log.debug("User: {}", user); - var remaining = store.checkLimit(user,inputTokens,maxOutputTokens); - if (remaining <= 0) { - exc.setResponse(tokenLimitExceeded()); - return RETURN; + if (exc.getRequest().isPOSTRequest()) { + inputTokens = aiReq.estimateInputTokens(); + var remaining = store.checkLimit(user, inputTokens, maxOutputTokens); + if (remaining <= 0) { + exc.setResponse(tokenLimitExceeded()); + return RETURN; + } } exc.setProperty(MEMBRANE_AI_USER, user); } @@ -102,6 +114,28 @@ public Outcome handleRequest(Exchange exc) { @Override public Outcome handleResponse(Exchange exc) { + var msg = exc.getResponse(); + + if (msg.isStream()) { + // Inspect each chunk as it arrives + msg.getBody().addObserver(new AbstractMessageObserver() { + @Override + public void bodyChunk(Chunk chunk) { + var event = SSEUtil.parseSSEvent(chunk); + if (event == null) { + return; + } + log.debug("SSE name: {}", event.name()); + if (OpenAiApiUtil.terminalEvent(event)) { + var jo = JsonUtil.getJsonObject(event.data()); + if (jo.isPresent()) { + log.debug("Usage: {}", jo.get().get("usage")); + } + } + } + }); + } + var aiRes = provider.getAiApiResponse(exc); if (aiRes.isError()) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java index 26df897f39..f3987bad07 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java @@ -2,6 +2,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.predic8.membrane.core.http.Response; +import com.predic8.membrane.core.util.SSEUtil; import java.util.Collection; @@ -13,6 +14,15 @@ public class OpenAiApiUtil { private static final ObjectMapper om = new ObjectMapper(); + /** + * Checks if the SSE Event is a terminal event. + * @param event SSE Event + * @return + */ + public static boolean terminalEvent(SSEUtil.SSEvent event) { + return "response.completed".equals(event.name()) || "response.incomplete".equals(event.name()); + } + public static Response modelNotAllowed(String model, Collection allowedModels) { return badRequest().json(createJson(new ErrorEnvelope( new ErrorBody( diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java index 96a4d39a4f..d15ffd5710 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java @@ -11,6 +11,10 @@ public ClaudeAiRequest(Exchange exchange) { super(exchange); } + public void setMaxOutputTokens(int maxOutputTokens) { + json.put("max_tokens", maxOutputTokens); + } + @Override public long estimateInputTokens() { // System prompt diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java index a1af123af7..c04aad9607 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java @@ -106,7 +106,7 @@ private long countText(JsonNode node) { @Override public void setMaxOutputTokens(int maxOutputTokens) { - getGenerationConfig().put("maxOutputTokens", maxOutputTokens); + getGenerationConfig().put("max_output_tokens", maxOutputTokens); } private ObjectNode getGenerationConfig() { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java index d52d8d06c9..8d3d2d9dde 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java @@ -8,6 +8,13 @@ public class OpenAiAiRequest extends AbstractAiApiRequest { public OpenAiAiRequest(Exchange exchange) { super(exchange); + + // Make sure that when streaming is enabled, the usage is included in the response. + if (json.path("stream").asBoolean(false)) { + if (isChatCompletionsRequest(exchange)) { + json.putObject("stream_options").put("include_usage", true); + } + } } @Override @@ -28,6 +35,13 @@ public long estimateInputTokens() { return Math.max(1, Math.round(chars / 4.0 * 1.15)); } + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + // OpenAI deprecated max_tokens for newer models (o1, o3, gpt-5.x) in + // favor of max_completion_tokens. Older models still accept max_tokens. + json.put("max_output_tokens", maxOutputTokens); + } + private long estimateChatCompletitions() { long chars = 0; // Chat Completions API @@ -84,4 +98,8 @@ private long countJsonSize(JsonNode node) { } return node.toString().length(); } + + private boolean isChatCompletionsRequest(Exchange exchange) { + return exchange.getRequest().getUri().contains("/chat/completions"); + } } diff --git a/core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java b/core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java new file mode 100644 index 0000000000..0bad51528d --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java @@ -0,0 +1,31 @@ +package com.predic8.membrane.core.util; + +import com.predic8.membrane.core.http.Chunk; + +/** + * Util for Server Sent Events. + */ +public class SSEUtil { + + private SSEUtil() {} + + public record SSEvent(String name, String data) {} + + public static SSEvent parseSSEvent(Chunk chunk) { + var content = chunk.toString(); + String event = null; + String data = null; + + for (var line : content.split("\n")) { + line = line.trim(); + if (line.startsWith("name:")) { + event = line.substring("name:".length()).trim(); + } else if (line.startsWith("data:")) { + data = line.substring("data:".length()).trim(); + } + } + + if (event == null && data == null) return null; + return new SSEvent(event, data); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java index 21b78f7128..b6c99902e5 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java @@ -23,6 +23,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.InputStream; import java.math.BigDecimal; import java.math.BigInteger; import java.util.Optional; @@ -92,9 +93,27 @@ public static JsonNode scalarAsJson(String value) { return FACTORY.textNode(value); } + public static Optional getJsonObject(String s) { + try { + var node = om.readTree(s); + if (node instanceof ObjectNode on) { + return Optional.of(on); + } + log.debug("Expected JSON Object but got: {}",node.getNodeType()); + } catch (Exception e) { + log.debug("Error reading JSON: {}", e.getMessage()); + } + return empty(); + } + + public static Optional getJsonObject(Message msg) { + return getJsonObjectFromSteam(msg.getBodyAsStreamDecoded()); + } + + private static Optional getJsonObjectFromSteam(InputStream obj) { try { - var node = om.readTree(msg.getBodyAsStreamDecoded()); + var node = om.readTree(obj); if (node instanceof ObjectNode on) { return Optional.of(on); } @@ -105,6 +124,7 @@ public static Optional getJsonObject(Message msg) { return empty(); } + public static void setJsonBody(Message msg, ObjectNode json) { try { if (!msg.isJSON()) { From d49826ab091b66bb88801bb244aec00a1d52eb4c Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 14 May 2026 13:31:13 +0200 Subject: [PATCH 11/43] refactor: rename AI classes and interfaces for consistency - Renamed `AiApiRequest` to `LLMRequest` and `AiApiResponse` to `LLMResponse`. - Updated providers (`Google`, `OpenAI`, `Claude`) to align with `LLMProvider` interface. - Refactored `OpenAIAPIInterceptor` to `LLMGatewayInterceptor` and related utilities. - Removed `SSEUtil` and replaced with `SSEParser`. - Improved streaming and token usage handling in `AbstractLLMResponse`. --- .../interceptor/ai/AbstractAiApiResponse.java | 50 ------ ...piRequest.java => AbstractLLMRequest.java} | 4 +- .../interceptor/ai/AbstractLLMResponse.java | 95 +++++++++++ .../{OpenAiApiUtil.java => LLMApiUtil.java} | 6 +- ...ceptor.java => LLMGatewayInterceptor.java} | 57 ++----- .../ai/{AiApiRequest.java => LLMRequest.java} | 2 +- .../{AiApiResponse.java => LLMResponse.java} | 3 +- .../interceptor/ai/provider/AiProvider.java | 12 -- .../core/interceptor/ai/provider/Claude.java | 20 --- .../ai/provider/ClaudeAiResponse.java | 12 -- ...deAiRequest.java => ClaudeLLMRequest.java} | 6 +- .../ai/provider/ClaudeLLMResponse.java | 15 ++ .../ai/provider/ClaudeProvider.java | 22 +++ .../core/interceptor/ai/provider/Google.java | 20 --- ...leAiRequest.java => GoogleLLMRequest.java} | 6 +- ...AiResponse.java => GoogleLLMResponse.java} | 11 +- .../ai/provider/GoogleProvider.java | 22 +++ .../interceptor/ai/provider/LLMProvider.java | 14 ++ .../core/interceptor/ai/provider/OpenAI.java | 20 --- .../ai/provider/OpenAIProvider.java | 22 +++ .../ai/provider/OpenAiAiResponse.java | 12 -- ...AiAiRequest.java => OpenAiLLMRequest.java} | 6 +- .../ai/provider/OpenAiLLMResponse.java | 32 ++++ .../predic8/membrane/core/util/SSEUtil.java | 31 ---- .../membrane/core/util/http/SSEParser.java | 140 ++++++++++++++++ .../core/util/http/SSEParserTest.java | 149 ++++++++++++++++++ 26 files changed, 550 insertions(+), 239 deletions(-) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{AbstractAiApiRequest.java => AbstractLLMRequest.java} (92%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{OpenAiApiUtil.java => LLMApiUtil.java} (95%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{OpenAIAPIInterceptor.java => LLMGatewayInterceptor.java} (70%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{AiApiRequest.java => LLMRequest.java} (90%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{AiApiResponse.java => LLMResponse.java} (83%) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{ClaudeAiRequest.java => ClaudeLLMRequest.java} (89%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{GoogleAiRequest.java => GoogleLLMRequest.java} (94%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{GoogleAiResponse.java => GoogleLLMResponse.java} (64%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleProvider.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAIProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{OpenAiAiRequest.java => OpenAiLLMRequest.java} (94%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java create mode 100644 core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java create mode 100644 core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java deleted file mode 100644 index 7ba2ed49dd..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiResponse.java +++ /dev/null @@ -1,50 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.node.JsonNodeFactory; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.store.Usage; -import com.predic8.membrane.core.util.json.JsonUtil; - -public abstract class AbstractAiApiResponse implements AiApiResponse { - - protected final Exchange exchange; - protected ObjectNode json; - - public AbstractAiApiResponse(Exchange exchange) { - this.exchange = exchange; - json = JsonUtil.getJsonObject(exchange.getResponse()).orElse(JsonNodeFactory.instance.objectNode().put("error", "No JSON object response from model.")); - } - - @Override - public boolean isError() { - return json.get("error") != null && !json.get("error").isNull(); - } - - public Usage getUsage() { - - var usage = json.path("usage"); - - int inputTokens = getInputTokens(usage); - int outputTokens = getOutputTokens(usage); - int totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); - - return new Usage( - inputTokens, - outputTokens, - totalTokens - ); - } - - private static int getOutputTokens(JsonNode usage) { - return usage.path("output_tokens").asInt( - usage.path("completion_tokens").asInt(0) - ); - } - - private static int getInputTokens(JsonNode usage) { - return usage.path("input_tokens").asInt( - usage.path("prompt_tokens").asInt(0)); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMRequest.java similarity index 92% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMRequest.java index 5166dc7820..ba561a3a7e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractAiApiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMRequest.java @@ -6,14 +6,14 @@ import static com.predic8.membrane.core.http.Header.AUTHORIZATION; -public abstract class AbstractAiApiRequest implements AiApiRequest { +public abstract class AbstractLLMRequest implements LLMRequest { public static final String BEARER_PREFIX = "Bearer"; protected final Exchange exchange; protected ObjectNode json; - public AbstractAiApiRequest(Exchange exchange) { + public AbstractLLMRequest(Exchange exchange) { this.exchange = exchange; if (exchange.getRequest().isJSON()) json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("No JSON object request.")); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMResponse.java new file mode 100644 index 0000000000..c795ba4397 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMResponse.java @@ -0,0 +1,95 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.http.AbstractMessageObserver; +import com.predic8.membrane.core.http.Chunk; +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser; +import com.predic8.membrane.core.util.json.JsonUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.function.Consumer; + +public abstract class AbstractLLMResponse implements LLMResponse { + + private static final Logger log = LoggerFactory.getLogger(AbstractLLMResponse.class); + + protected final Exchange exchange; + protected ObjectNode json; + Consumer postProcessor; + + public AbstractLLMResponse(Exchange exchange, Consumer postProcessor) { + this.exchange = exchange; + this.postProcessor = postProcessor; + var msg = exchange.getResponse(); + + if (msg.isStream()) { + + var parser = new SSEParser("response.completed","response.incompleted"); + + msg.getBody().addObserver(new AbstractMessageObserver() { + @Override + public void bodyChunk(Chunk chunk) { + if (!parser.parse(chunk)) { + return; + } + + var events = parser.getEvents(); + var terminal = parser.getTerminalEvent(); + + log.debug("---------------------------------------------------------------"); + log.debug("Events: {}", events.size()); + events.forEach(e -> log.debug("Event: {}", e)); + log.debug("---------------------------------------------------------------"); + + + terminal.ifPresent(event -> { + json = JsonUtil.getJsonObject(event.data()) + .orElse(JsonNodeFactory.instance.objectNode() + .put("error", "No JSON object response from model.")); + + postProcessor.accept(AbstractLLMResponse.this); + }); + } + }); + } else { + json = JsonUtil.getJsonObject(exchange.getResponse()) + .orElse(JsonNodeFactory.instance.objectNode().put("error", "No JSON object response from model.")); + } + } + + @Override + public boolean isError() { + return json.get("error") != null && !json.get("error").isNull(); + } + + public Usage getUsage() { + + var usage = json.path("usage"); + + int inputTokens = getInputTokens(usage); + int outputTokens = getOutputTokens(usage); + int totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); + + return new Usage( + inputTokens, + outputTokens, + totalTokens + ); + } + + protected static int getOutputTokens(JsonNode usage) { + return usage.path("output_tokens").asInt( + usage.path("completion_tokens").asInt(0) + ); + } + + protected static int getInputTokens(JsonNode usage) { + return usage.path("input_tokens").asInt( + usage.path("prompt_tokens").asInt(0)); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java similarity index 95% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java index f3987bad07..19cdcc1f95 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAiApiUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java @@ -2,7 +2,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.predic8.membrane.core.http.Response; -import com.predic8.membrane.core.util.SSEUtil; +import com.predic8.membrane.core.util.http.SSEParser; import java.util.Collection; @@ -10,7 +10,7 @@ import static com.predic8.membrane.core.http.Response.badRequest; import static com.predic8.membrane.core.http.Response.unauthorized; -public class OpenAiApiUtil { +public class LLMApiUtil { private static final ObjectMapper om = new ObjectMapper(); @@ -19,7 +19,7 @@ public class OpenAiApiUtil { * @param event SSE Event * @return */ - public static boolean terminalEvent(SSEUtil.SSEvent event) { + public static boolean terminalEvent(SSEParser.SSEEvent event) { return "response.completed".equals(event.name()) || "response.incomplete".equals(event.name()); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java similarity index 70% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 347f68d718..7af365e473 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/OpenAIAPIInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -4,15 +4,11 @@ import com.predic8.membrane.annot.MCChildElement; import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.http.AbstractMessageObserver; -import com.predic8.membrane.core.http.Chunk; import com.predic8.membrane.core.interceptor.AbstractInterceptor; import com.predic8.membrane.core.interceptor.Outcome; -import com.predic8.membrane.core.interceptor.ai.provider.AiProvider; +import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; import com.predic8.membrane.core.interceptor.ai.store.AiApiUser; -import com.predic8.membrane.core.util.SSEUtil; -import com.predic8.membrane.core.util.json.JsonUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -21,17 +17,17 @@ import static com.predic8.membrane.core.exceptions.ProblemDetails.user; import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; -import static com.predic8.membrane.core.interceptor.ai.OpenAiApiUtil.*; +import static com.predic8.membrane.core.interceptor.ai.LLMApiUtil.*; import static com.predic8.membrane.core.util.json.JsonUtil.setJsonBody; @MCElement(name = "aiGateway") -public class OpenAIAPIInterceptor extends AbstractInterceptor { +public class LLMGatewayInterceptor extends AbstractInterceptor { - private static final Logger log = LoggerFactory.getLogger(OpenAIAPIInterceptor.class); + private static final Logger log = LoggerFactory.getLogger(LLMGatewayInterceptor.class); public static final String MEMBRANE_AI_USER = "membrane.ai.user"; - private AiProvider provider; + private LLMProvider provider; private String apiKey; private int maxOutputTokens; @@ -49,9 +45,9 @@ public void init() { @Override public Outcome handleRequest(Exchange exc) { - AiApiRequest aiReq; + LLMRequest aiReq; try { - aiReq = provider.getAiApiRequest(exc); + aiReq = provider.getLLMRequest(exc); } catch (Exception e) { user(router.getConfiguration().isProduction(),"AI Gateway") .title("Invalid request") @@ -114,36 +110,13 @@ public Outcome handleRequest(Exchange exc) { @Override public Outcome handleResponse(Exchange exc) { - var msg = exc.getResponse(); - - if (msg.isStream()) { - // Inspect each chunk as it arrives - msg.getBody().addObserver(new AbstractMessageObserver() { - @Override - public void bodyChunk(Chunk chunk) { - var event = SSEUtil.parseSSEvent(chunk); - if (event == null) { - return; - } - log.debug("SSE name: {}", event.name()); - if (OpenAiApiUtil.terminalEvent(event)) { - var jo = JsonUtil.getJsonObject(event.data()); - if (jo.isPresent()) { - log.debug("Usage: {}", jo.get().get("usage")); - } - } - } - }); - } - - var aiRes = provider.getAiApiResponse(exc); - - if (aiRes.isError()) - return CONTINUE; // pass error from AI API to client + var aiRes = provider.getLLMResponse(exc, res -> { + System.out.println("Usage: " + res.getUsage()); + if (store != null) { + store.store(exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class), res.getUsage()); + } + }); - if (store != null) { - store.store(exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class), aiRes.getUsage()); - } return CONTINUE; } @@ -197,12 +170,12 @@ public void setModels(List models) { this.models = models; } - public AiProvider getProvider() { + public LLMProvider getProvider() { return provider; } @MCChildElement(allowForeign = true) - public void setProvider(AiProvider provider) { + public void setProvider(LLMProvider provider) { this.provider = provider; } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMRequest.java similarity index 90% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMRequest.java index cec5f30079..d401969e8c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMRequest.java @@ -2,7 +2,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; -public interface AiApiRequest { +public interface LLMRequest { String getModel(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMResponse.java similarity index 83% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMResponse.java index b01e135626..45be2e0a46 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AiApiResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMResponse.java @@ -2,9 +2,10 @@ import com.predic8.membrane.core.interceptor.ai.store.Usage; -public interface AiApiResponse { +public interface LLMResponse { boolean isError(); Usage getUsage(); + } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java deleted file mode 100644 index 4b99f8f46d..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AiProvider.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AiApiRequest; -import com.predic8.membrane.core.interceptor.ai.AiApiResponse; - -public interface AiProvider { - - AiApiRequest getAiApiRequest(Exchange request); - AiApiResponse getAiApiResponse(Exchange request); - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java deleted file mode 100644 index 93b60ba26f..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Claude.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AiApiRequest; -import com.predic8.membrane.core.interceptor.ai.AiApiResponse; - -@MCElement( name="claude") -public class Claude implements AiProvider { - - @Override - public AiApiRequest getAiApiRequest(Exchange exchange) { - return new ClaudeAiRequest(exchange); - } - - @Override - public AiApiResponse getAiApiResponse(Exchange request) { - return new ClaudeAiResponse(request); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java deleted file mode 100644 index a500e5bfdf..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiResponse.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractAiApiResponse; - -public class ClaudeAiResponse extends AbstractAiApiResponse { - - public ClaudeAiResponse(Exchange exchange) { - super(exchange); - } - -} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMRequest.java similarity index 89% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMRequest.java index d15ffd5710..d0bbb0be1a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMRequest.java @@ -1,13 +1,13 @@ package com.predic8.membrane.core.interceptor.ai.provider; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractAiApiRequest; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMRequest; -public class ClaudeAiRequest extends AbstractAiApiRequest { +public class ClaudeLLMRequest extends AbstractLLMRequest { public static final String X_API_KEY = "x-api-key"; - public ClaudeAiRequest(Exchange exchange) { + public ClaudeLLMRequest(Exchange exchange) { super(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java new file mode 100644 index 0000000000..25d4726c66 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java @@ -0,0 +1,15 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.LLMResponse; + +import java.util.function.Consumer; + +public class ClaudeLLMResponse extends AbstractLLMResponse { + + public ClaudeLLMResponse(Exchange exchange, Consumer postProcessor) { + super(exchange,postProcessor); + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeProvider.java new file mode 100644 index 0000000000..d415c09881 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeProvider.java @@ -0,0 +1,22 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.LLMRequest; +import com.predic8.membrane.core.interceptor.ai.LLMResponse; + +import java.util.function.Consumer; + +@MCElement( name="claude") +public class ClaudeProvider implements LLMProvider { + + @Override + public LLMRequest getLLMRequest(Exchange exchange) { + return new ClaudeLLMRequest(exchange); + } + + @Override + public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { + return new ClaudeLLMResponse(request, postProcessor); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java deleted file mode 100644 index b87ed46c11..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/Google.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AiApiRequest; -import com.predic8.membrane.core.interceptor.ai.AiApiResponse; - -@MCElement( name="google",id = "google-ai-provider") -public class Google implements AiProvider { - - @Override - public AiApiRequest getAiApiRequest(Exchange exchange) { - return new GoogleAiRequest(exchange); - } - - @Override - public AiApiResponse getAiApiResponse(Exchange request) { - return new GoogleAiResponse(request); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMRequest.java similarity index 94% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMRequest.java index c04aad9607..693827ebcd 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMRequest.java @@ -3,13 +3,13 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractAiApiRequest; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMRequest; -public class GoogleAiRequest extends AbstractAiApiRequest { +public class GoogleLLMRequest extends AbstractLLMRequest { public static final String X_GOOG_API_KEY = "x-goog-api-key"; - public GoogleAiRequest(Exchange exchange) { + public GoogleLLMRequest(Exchange exchange) { super(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMResponse.java similarity index 64% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMResponse.java index 66add9a43e..82fec4b0ba 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleAiResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMResponse.java @@ -1,13 +1,16 @@ package com.predic8.membrane.core.interceptor.ai.provider; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractAiApiResponse; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.LLMResponse; import com.predic8.membrane.core.interceptor.ai.store.Usage; -public class GoogleAiResponse extends AbstractAiApiResponse { +import java.util.function.Consumer; - public GoogleAiResponse(Exchange exchange) { - super(exchange); +public class GoogleLLMResponse extends AbstractLLMResponse { + + public GoogleLLMResponse(Exchange exchange, Consumer postProcessor) { + super(exchange, postProcessor); } @Override diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleProvider.java new file mode 100644 index 0000000000..0cf43df16d --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleProvider.java @@ -0,0 +1,22 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.LLMRequest; +import com.predic8.membrane.core.interceptor.ai.LLMResponse; + +import java.util.function.Consumer; + +@MCElement( name="google",id = "google-ai-provider") +public class GoogleProvider implements LLMProvider { + + @Override + public LLMRequest getLLMRequest(Exchange exchange) { + return new GoogleLLMRequest(exchange); + } + + @Override + public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { + return new GoogleLLMResponse(request, postProcessor); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java new file mode 100644 index 0000000000..28cb31b5bc --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java @@ -0,0 +1,14 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.LLMRequest; +import com.predic8.membrane.core.interceptor.ai.LLMResponse; + +import java.util.function.Consumer; + +public interface LLMProvider { + + LLMRequest getLLMRequest(Exchange request); + LLMResponse getLLMResponse(Exchange request, Consumer postProcessor); + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java deleted file mode 100644 index 91b0ee8889..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAI.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AiApiRequest; -import com.predic8.membrane.core.interceptor.ai.AiApiResponse; - -@MCElement( name="openai") -public class OpenAI implements AiProvider{ - - @Override - public AiApiRequest getAiApiRequest(Exchange exchange) { - return new OpenAiAiRequest(exchange); - } - - @Override - public AiApiResponse getAiApiResponse(Exchange request) { - return new OpenAiAiResponse(request); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAIProvider.java new file mode 100644 index 0000000000..92dd80b3e2 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAIProvider.java @@ -0,0 +1,22 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.LLMRequest; +import com.predic8.membrane.core.interceptor.ai.LLMResponse; + +import java.util.function.Consumer; + +@MCElement( name="openai") +public class OpenAIProvider implements LLMProvider { + + @Override + public LLMRequest getLLMRequest(Exchange exchange) { + return new OpenAiLLMRequest(exchange); + } + + @Override + public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { + return new OpenAiLLMResponse(request, postProcessor); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java deleted file mode 100644 index a20575cf68..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiResponse.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractAiApiResponse; - -public class OpenAiAiResponse extends AbstractAiApiResponse { - - public OpenAiAiResponse(Exchange exchange) { - super(exchange); - } - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMRequest.java similarity index 94% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMRequest.java index 8d3d2d9dde..940b248a6c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiAiRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMRequest.java @@ -2,11 +2,11 @@ import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractAiApiRequest; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMRequest; -public class OpenAiAiRequest extends AbstractAiApiRequest { +public class OpenAiLLMRequest extends AbstractLLMRequest { - public OpenAiAiRequest(Exchange exchange) { + public OpenAiLLMRequest(Exchange exchange) { super(exchange); // Make sure that when streaming is enabled, the usage is included in the response. diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java new file mode 100644 index 0000000000..31544e59e6 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java @@ -0,0 +1,32 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.store.Usage; + +import java.util.function.Consumer; + +public class OpenAiLLMResponse extends AbstractLLMResponse { + + public OpenAiLLMResponse(Exchange exchange, Consumer postProcessor) { + super(exchange,postProcessor); + } + + @Override + public Usage getUsage() { + + var usage = json.path("response").path("usage"); + + int inputTokens = getInputTokens(usage); + int outputTokens = getOutputTokens(usage); + int totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); + + return new Usage( + inputTokens, + outputTokens, + totalTokens + ); + } + +} diff --git a/core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java b/core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java deleted file mode 100644 index 0bad51528d..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/util/SSEUtil.java +++ /dev/null @@ -1,31 +0,0 @@ -package com.predic8.membrane.core.util; - -import com.predic8.membrane.core.http.Chunk; - -/** - * Util for Server Sent Events. - */ -public class SSEUtil { - - private SSEUtil() {} - - public record SSEvent(String name, String data) {} - - public static SSEvent parseSSEvent(Chunk chunk) { - var content = chunk.toString(); - String event = null; - String data = null; - - for (var line : content.split("\n")) { - line = line.trim(); - if (line.startsWith("name:")) { - event = line.substring("name:".length()).trim(); - } else if (line.startsWith("data:")) { - data = line.substring("data:".length()).trim(); - } - } - - if (event == null && data == null) return null; - return new SSEvent(event, data); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java b/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java new file mode 100644 index 0000000000..1fe9f7fb7c --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java @@ -0,0 +1,140 @@ +package com.predic8.membrane.core.util.http; + +import com.predic8.membrane.core.http.Chunk; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +public final class SSEParser { + + private final Set terminalEventNames; + private final StringBuilder buffer = new StringBuilder(); + + private final List events = new ArrayList<>(); + + private String eventName; + private final StringBuilder data = new StringBuilder(); + + private boolean terminalFound; + + public SSEParser(String... terminalEventNames) { + this.terminalEventNames = Set.of(terminalEventNames); + } + + public boolean parse(Chunk chunk) { + if (terminalFound) { + return true; + } + + buffer.append(chunk.toString()); + + int lineEnd; + while ((lineEnd = findLineEnd(buffer)) >= 0) { + String line = readLine(buffer, lineEnd); + + if (line.isEmpty()) { + var event = buildEvent(); + resetEvent(); + + if (event != null) { + events.add(event); + + if (terminalEventNames.contains(event.name())) { + terminalFound = true; + return true; + } + } + + continue; + } + + parseLine(line); + } + + return false; + } + + public List getEvents() { + return List.copyOf(events); + } + + public Optional getTerminalEvent() { + if (!terminalFound || events.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(events.getLast()); + } + + private SSEEvent buildEvent() { + if (eventName == null && data.isEmpty()) { + return null; + } + + return new SSEEvent(eventName, data.isEmpty() ? null : data.toString()); + } + + private void resetEvent() { + eventName = null; + data.setLength(0); + } + + private void parseLine(String line) { + if (line.startsWith(":")) { + return; + } + + int colon = line.indexOf(':'); + + String field = colon >= 0 ? line.substring(0, colon) : line; + String value = colon >= 0 ? line.substring(colon + 1) : ""; + + if (value.startsWith(" ")) { + value = value.substring(1); + } + + switch (field) { + case "event" -> eventName = value; + + case "data" -> { + if (!data.isEmpty()) { + data.append('\n'); + } + data.append(value); + } + + default -> { + // ignore id, retry, unknown fields + } + } + } + + private static int findLineEnd(StringBuilder buffer) { + for (int i = 0; i < buffer.length(); i++) { + char c = buffer.charAt(i); + if (c == '\n' || c == '\r') { + return i; + } + } + return -1; + } + + private static String readLine(StringBuilder buffer, int lineEnd) { + String line = buffer.substring(0, lineEnd); + + int removeUntil = lineEnd + 1; + + if (lineEnd + 1 < buffer.length() + && buffer.charAt(lineEnd) == '\r' + && buffer.charAt(lineEnd + 1) == '\n') { + removeUntil++; + } + + buffer.delete(0, removeUntil); + return line; + } + + public record SSEEvent(String name, String data) {} +} \ No newline at end of file diff --git a/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java b/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java new file mode 100644 index 0000000000..2421770991 --- /dev/null +++ b/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java @@ -0,0 +1,149 @@ +package com.predic8.membrane.core.util.http; + +import com.predic8.membrane.core.http.Chunk; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class SSEParserTest { + + @Test + void parsesSingleEvent() { + var parser = new SSEParser("done"); + + assertFalse(parser.parse(chunk(""" + event: message + data: hello + + """))); + + var events = parser.getEvents(); + + assertEquals(1, events.size()); + assertEquals("message", events.getFirst().name()); + assertEquals("hello", events.getFirst().data()); + assertTrue(parser.getTerminalEvent().isEmpty()); + } + + @Test + void parsesMultilineData() { + var parser = new SSEParser("done"); + + parser.parse(chunk(""" + event: message + data: first + data: second + + """)); + + assertEquals("first\nsecond", parser.getEvents().getFirst().data()); + } + + @Test + void parsesEventSplitAcrossChunks() { + var parser = new SSEParser("done"); + + assertFalse(parser.parse(chunk(""" + event: mes"""))); + + assertFalse(parser.parse(chunk(""" + sage + data: hel"""))); + + assertFalse(parser.parse(chunk(""" + lo + + """))); + + var event = parser.getEvents().getFirst(); + + assertEquals("message", event.name()); + assertEquals("hello", event.data()); + } + + @Test + void returnsTrueWhenTerminalEventIsFound() { + var parser = new SSEParser("done"); + + assertTrue(parser.parse(chunk(""" + event: done + data: {"usage":{"total_tokens":42}} + + """))); + + var terminal = parser.getTerminalEvent(); + + assertTrue(terminal.isPresent()); + assertEquals("done", terminal.get().name()); + assertEquals("{\"usage\":{\"total_tokens\":42}}", terminal.get().data()); + } + + @Test + void ignoresChunksAfterTerminalEvent() { + var parser = new SSEParser("done"); + + assertTrue(parser.parse(chunk(""" + event: done + data: final + + """))); + + assertTrue(parser.parse(chunk(""" + event: message + data: ignored + + """))); + + assertEquals(1, parser.getEvents().size()); + assertEquals("done", parser.getEvents().getFirst().name()); + } + + @Test + void ignoresCommentsAndUnknownFields() { + var parser = new SSEParser("done"); + + parser.parse(chunk(""" + : comment + id: 123 + retry: 1000 + event: message + data: hello + + """)); + + var event = parser.getEvents().getFirst(); + + assertEquals("message", event.name()); + assertEquals("hello", event.data()); + } + + @Test + void supportsCrLfLineEndings() { + var parser = new SSEParser("done"); + + parser.parse(chunk("event: message\r\ndata: hello\r\n\r\n")); + + var event = parser.getEvents().getFirst(); + + assertEquals("message", event.name()); + assertEquals("hello", event.data()); + } + + @Test + void returnsUnmodifiableEventsList() { + var parser = new SSEParser("done"); + + parser.parse(chunk(""" + event: message + data: hello + + """)); + + assertThrows(UnsupportedOperationException.class, + () -> parser.getEvents().add(new SSEParser.SSEEvent("x", "y"))); + } + + private static Chunk chunk(String content) { + return new Chunk(content.getBytes()); + } +} \ No newline at end of file From f0a3b2139b70c71c40d0c996a30bd58ddc27be6b Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 14 May 2026 13:32:21 +0200 Subject: [PATCH 12/43] refactor: remove unused `terminalEvent` method from `LLMApiUtil` --- .../membrane/core/interceptor/ai/LLMApiUtil.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java index 19cdcc1f95..2f86121750 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java @@ -2,7 +2,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.predic8.membrane.core.http.Response; -import com.predic8.membrane.core.util.http.SSEParser; import java.util.Collection; @@ -14,15 +13,6 @@ public class LLMApiUtil { private static final ObjectMapper om = new ObjectMapper(); - /** - * Checks if the SSE Event is a terminal event. - * @param event SSE Event - * @return - */ - public static boolean terminalEvent(SSEParser.SSEEvent event) { - return "response.completed".equals(event.name()) || "response.incomplete".equals(event.name()); - } - public static Response modelNotAllowed(String model, Collection allowedModels) { return badRequest().json(createJson(new ErrorEnvelope( new ErrorBody( From f11e491efc927709c5f3fc5078efb647baf0e717 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 18 May 2026 11:51:42 +0200 Subject: [PATCH 13/43] feat: refactor LLM APIs and improve SSE-driven event handling - Modularized LLM responses for `Claude` and `OpenAI` providers. - Replaced `LLMResponse` interface and `AbstractLLMResponse` with updated abstractions. - Added `ChatCompletionsSSEParser` for advanced SSE chunk handling. - Introduced specific SSE event classes: `ChatCompletionEvent`, `ChatCompletionDoneEvent`, `ResponsesApiEvent`. - Renamed and restructured classes for consistency in AI namespace. - Improved token usage tracking and event-based streaming. --- .../core/interceptor/ai/AbstractLLMEvent.java | 55 +++++ .../ai/ChatCompletionDoneEvent.java | 15 ++ .../interceptor/ai/ChatCompletionEvent.java | 70 ++++++ .../ai/ChatCompletionsSSEParser.java | 203 ++++++++++++++++++ .../interceptor/ai/LLMGatewayInterceptor.java | 4 +- .../core/interceptor/ai/LLMResponse.java | 11 - .../interceptor/ai/ResponsesApiEvent.java | 41 ++++ .../ai/provider/AbstractLLMMessage.java | 25 +++ .../ai/{ => provider}/AbstractLLMRequest.java | 34 ++- .../{ => provider}/AbstractLLMResponse.java | 76 +++---- .../ai/provider/ClaudeLLMRequest.java | 52 ----- .../ai/provider/ClaudeLLMResponse.java | 15 -- .../interceptor/ai/provider/LLMProvider.java | 2 - .../ai/{ => provider}/LLMRequest.java | 6 +- .../interceptor/ai/provider/LLMResponse.java | 18 ++ .../ai/provider/OpenAiLLMResponse.java | 32 --- .../ai/provider/claude/ClaudeLLMRequest.java | 89 ++++++++ .../ai/provider/claude/ClaudeLLMResponse.java | 64 ++++++ .../provider/{ => claude}/ClaudeProvider.java | 7 +- .../ai/provider/claude/ContentBlockDelta.java | 39 ++++ .../ai/provider/claude/ContentBlockStart.java | 23 ++ .../ai/provider/claude/MessageDelta.java | 70 ++++++ .../ai/provider/claude/ToolUse.java | 22 ++ .../{ => google}/GoogleLLMRequest.java | 4 +- .../{ => google}/GoogleLLMResponse.java | 18 +- .../provider/{ => google}/GoogleProvider.java | 7 +- .../provider/{ => openai}/OpenAIProvider.java | 7 +- .../{ => openai}/OpenAiLLMRequest.java | 15 +- .../ai/provider/openai/OpenAiLLMResponse.java | 69 ++++++ .../membrane/core/util/http/SSEParser.java | 30 ++- .../core/util/http/SSEParserTest.java | 18 +- 31 files changed, 952 insertions(+), 189 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMEvent.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{ => provider}/AbstractLLMRequest.java (58%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{ => provider}/AbstractLLMResponse.java (50%) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMRequest.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/{ => provider}/LLMRequest.java (71%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{ => claude}/ClaudeProvider.java (65%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockDelta.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{ => google}/GoogleLLMRequest.java (95%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{ => google}/GoogleLLMResponse.java (61%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{ => google}/GoogleProvider.java (66%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{ => openai}/OpenAIProvider.java (65%) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/{ => openai}/OpenAiLLMRequest.java (87%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMEvent.java new file mode 100644 index 0000000000..d39b3ddafe --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMEvent.java @@ -0,0 +1,55 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.predic8.membrane.core.util.http.SSEParser; +import com.predic8.membrane.core.util.json.JsonUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class AbstractLLMEvent { + + private static final Logger log = LoggerFactory.getLogger(AbstractLLMEvent.class); + + protected static final ObjectMapper om = new ObjectMapper(); + + protected final JsonNode json; + + protected AbstractLLMEvent(JsonNode json) { + this.json = json; + } + + public abstract String getType(); + + public JsonNode getJson() { + return json; + } + + public static AbstractLLMEvent create(SSEParser.SSEEvent sse) { + + if ("[DONE]".equals(sse.data())) { + return new ChatCompletionDoneEvent(); + } + + var opt = JsonUtil.getJsonObject(sse.data()); + if (opt.isEmpty()) { + log.info("Unknown event format: {}", sse.data()); + } + + var json = opt.get(); + + // Responses API + if (json.has("type")) { + return new ResponsesApiEvent(json); + } + + // Chat Completions API + if ("chat.completion.chunk".equals(json.path("object").asText())) { + return new ChatCompletionEvent(json); + } + + log.debug("Unknown event format: {}", json); + + return null; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java new file mode 100644 index 0000000000..520118262a --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java @@ -0,0 +1,15 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.node.NullNode; + +public class ChatCompletionDoneEvent extends AbstractLLMEvent { + + public ChatCompletionDoneEvent() { + super(NullNode.getInstance()); + } + + @Override + public String getType() { + return "chat.completion.done"; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java new file mode 100644 index 0000000000..7782c93f45 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java @@ -0,0 +1,70 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.JsonNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ChatCompletionEvent extends AbstractLLMEvent { + + private static final Logger log = LoggerFactory.getLogger(ChatCompletionEvent.class); + + public ChatCompletionEvent(JsonNode json) { + super(json); + + parseChoices(json); + + var usage = json.path("usage"); + if (!usage.isNull()) { + var inputTokens = usage.get("prompt_tokens").asInt(); + var outputTokens = usage.get("completion_tokens").asInt(); + var totalTokens = usage.get("total_tokens").asInt(); + System.out.println("------------------------------totalTokens = " + totalTokens); + } + } + + + private static void parseChoices(JsonNode json) { + for (JsonNode choice : json.path("choices")) { + + JsonNode delta = choice.path("delta"); + + if (delta.has("content")) { + log.debug("Content delta: {}", + delta.path("content").asText()); + } + + if (delta.has("tool_calls")) { + + for (JsonNode tc : delta.path("tool_calls")) { + + JsonNode fn = tc.path("function"); + + if (fn.has("name")) { + log.debug("Tool call name delta: {}", + fn.path("name").asText()); + } + + if (fn.has("arguments")) { + log.debug("Tool call arguments delta: {}", + fn.path("arguments").asText()); + } + } + } + + String finishReason = choice.path("finish_reason").asText(null); + + if (finishReason != null && !"null".equals(finishReason)) { + log.debug("Finish reason: {}", finishReason); + } + } + } + + @Override + public String getType() { + return "chat.completion.chunk"; + } + + public JsonNode getChoices() { + return json.path("choices"); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java new file mode 100644 index 0000000000..7ac95f7456 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java @@ -0,0 +1,203 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.http.Chunk; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public final class ChatCompletionsSSEParser { + + private static final Logger log = LoggerFactory.getLogger(ChatCompletionsSSEParser.class); + + private final StringBuilder buffer = new StringBuilder(); + private final List chunks = new ArrayList<>(); + private final StringBuilder data = new StringBuilder(); + + private boolean done; + + public boolean parse(Chunk chunk) { + if (done) + return true; + + log.debug("Parsing chat completions SSE chunk: {}", chunk); + + buffer.append(chunk.toString()); + + int lineEnd; + while ((lineEnd = findLineEnd(buffer)) >= 0) { + String line = readLine(buffer, lineEnd); + + if (line.isEmpty()) { + ChatCompletionChunk parsedChunk = buildChunk(); + resetEvent(); + + if (parsedChunk != null) { + if (parsedChunk.isDone()) { + done = true; + return true; + } + + chunks.add(parsedChunk); + } + + continue; + } + + parseLine(line); + } + + return false; + } + + public List getChunks() { + return List.copyOf(chunks); + } + + public Optional getLastChunk() { + if (chunks.isEmpty()) + return Optional.empty(); + + return Optional.of(chunks.get(chunks.size() - 1)); + } + + public boolean isDone() { + return done; + } + + private ChatCompletionChunk buildChunk() { + if (data.isEmpty()) + return null; + + String value = data.toString(); + + if ("[DONE]".equals(value)) + return ChatCompletionChunk.done(); + + return ChatCompletionChunk.json(value); + } + + private void resetEvent() { + data.setLength(0); + } + + private void parseLine(String line) { + if (line.startsWith(":")) + return; + + int colon = line.indexOf(':'); + + String field = colon >= 0 ? line.substring(0, colon) : line; + String value = colon >= 0 ? line.substring(colon + 1) : ""; + + if (value.startsWith(" ")) + value = value.substring(1); + + if ("data".equals(field)) { + if (!data.isEmpty()) + data.append('\n'); + + data.append(value); + } + } + + private static int findLineEnd(StringBuilder buffer) { + for (int i = 0; i < buffer.length(); i++) { + char c = buffer.charAt(i); + + if (c == '\n' || c == '\r') + return i; + } + + return -1; + } + + private static String readLine(StringBuilder buffer, int lineEnd) { + String line = buffer.substring(0, lineEnd); + + int removeUntil = lineEnd + 1; + + if (lineEnd + 1 < buffer.length() + && buffer.charAt(lineEnd) == '\r' + && buffer.charAt(lineEnd + 1) == '\n') { + removeUntil++; + } + + buffer.delete(0, removeUntil); + return line; + } + + public static final class ChatCompletionChunk { + + private static final ObjectMapper om = new ObjectMapper(); + + private final boolean done; + private final String data; + private ObjectNode json; + + private ChatCompletionChunk(boolean done, String data) { + this.done = done; + this.data = data; + } + + public static ChatCompletionChunk done() { + return new ChatCompletionChunk(true, null); + } + + public static ChatCompletionChunk json(String data) { + return new ChatCompletionChunk(false, data); + } + + public boolean isDone() { + return done; + } + + public String getData() { + return data; + } + + public ObjectNode json() { + if (done) + throw new IllegalStateException("[DONE] has no JSON body."); + + if (json != null) + return json; + + try { + json = (ObjectNode) om.readTree(data); + return json; + } catch (JsonProcessingException e) { + throw new RuntimeException("Could not parse chat completion chunk JSON.", e); + } + } + + public String contentDelta() { + if (done) + return null; + + return json() + .path("choices") + .path(0) + .path("delta") + .path("content") + .asText(null); + } + + public boolean hasToolCalls() { + if (done) + return false; + + return json() + .path("choices") + .path(0) + .path("delta") + .path("tool_calls") + .isArray(); + } + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 7af365e473..fa5f6748b3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -7,6 +7,7 @@ import com.predic8.membrane.core.interceptor.AbstractInterceptor; import com.predic8.membrane.core.interceptor.Outcome; import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; import com.predic8.membrane.core.interceptor.ai.store.AiApiUser; import org.slf4j.Logger; @@ -103,6 +104,8 @@ public Outcome handleRequest(Exchange exc) { } } + log.debug("Tools: {}", aiReq.getTools()); + setJsonBody(exc.getRequest(), aiReq.getJson()); return CONTINUE; } @@ -111,7 +114,6 @@ public Outcome handleRequest(Exchange exc) { public Outcome handleResponse(Exchange exc) { var aiRes = provider.getLLMResponse(exc, res -> { - System.out.println("Usage: " + res.getUsage()); if (store != null) { store.store(exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class), res.getUsage()); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMResponse.java deleted file mode 100644 index 45be2e0a46..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMResponse.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai; - -import com.predic8.membrane.core.interceptor.ai.store.Usage; - -public interface LLMResponse { - - boolean isError(); - - Usage getUsage(); - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java new file mode 100644 index 0000000000..ab2c6fab1a --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java @@ -0,0 +1,41 @@ +package com.predic8.membrane.core.interceptor.ai; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ResponsesApiEvent extends AbstractLLMEvent { + + private static final Logger log = LoggerFactory.getLogger(ResponsesApiEvent.class); + + private final String type; + + public ResponsesApiEvent(JsonNode json) { + super(json); + + this.type = json.path("type").asText(); + + log.debug("Responses API event: {}", type); + + if ("response.output_item.done".equals(type)) { + + JsonNode item = json.path("item"); + + if (item.isObject()) { + ObjectNode on = (ObjectNode) item; + + if ("function_call".equals(on.path("type").asText())) { + log.info("Function call: {} with {}", + on.path("name").asText(), + on.path("arguments").asText()); + } + } + } + } + + @Override + public String getType() { + return type; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java new file mode 100644 index 0000000000..488dabe3ce --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java @@ -0,0 +1,25 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.exchange.Exchange; + +public class AbstractLLMMessage { + + protected final Exchange exchange; + + public enum API { COMPLETIONS, NORMAL } + + protected API api; + + protected AbstractLLMMessage(Exchange exchange) { + this.exchange = exchange; + api = getAPI(exchange); + } + + protected API getAPI(Exchange exchange) { + if (exchange.getRequest().getUri().contains("/chat/completions")) { + return API.COMPLETIONS; + } else { + return API.NORMAL; + } + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java similarity index 58% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java index ba561a3a7e..01d01003df 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java @@ -1,22 +1,46 @@ -package com.predic8.membrane.core.interceptor.ai; +package com.predic8.membrane.core.interceptor.ai.provider; +import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.util.json.JsonUtil; +import java.util.List; + import static com.predic8.membrane.core.http.Header.AUTHORIZATION; +import static java.util.Collections.emptyList; -public abstract class AbstractLLMRequest implements LLMRequest { +public abstract class AbstractLLMRequest extends AbstractLLMMessage implements LLMRequest { public static final String BEARER_PREFIX = "Bearer"; - protected final Exchange exchange; protected ObjectNode json; public AbstractLLMRequest(Exchange exchange) { - this.exchange = exchange; - if (exchange.getRequest().isJSON()) + super(exchange); + + if (exchange.getRequest().isJSON()) { json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("No JSON object request.")); + + if (json.path("tools").isArray()) { + //System.out.println("Tools: " + json.path("tools")); + } + } + } + + public List getTools() { + var tools = getToolsNode(); + if (tools == null) + return emptyList(); + return tools.valueStream().map(n -> n.get("name").asText()).toList(); + } + + private ArrayNode getToolsNode() { + if (json == null) + return null; + if (json.path("tools").isArray()) + return (ArrayNode) json.path("tools"); + return null; } @Override diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java similarity index 50% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java index c795ba4397..d1550af0d2 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java @@ -1,4 +1,4 @@ -package com.predic8.membrane.core.interceptor.ai; +package com.predic8.membrane.core.interceptor.ai.provider; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; @@ -6,7 +6,6 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.http.AbstractMessageObserver; import com.predic8.membrane.core.http.Chunk; -import com.predic8.membrane.core.interceptor.ai.store.Usage; import com.predic8.membrane.core.util.http.SSEParser; import com.predic8.membrane.core.util.json.JsonUtil; import org.slf4j.Logger; @@ -14,72 +13,67 @@ import java.util.function.Consumer; -public abstract class AbstractLLMResponse implements LLMResponse { +public abstract class AbstractLLMResponse extends AbstractLLMMessage implements LLMResponse { private static final Logger log = LoggerFactory.getLogger(AbstractLLMResponse.class); - protected final Exchange exchange; protected ObjectNode json; - Consumer postProcessor; + protected Consumer postProcessor; public AbstractLLMResponse(Exchange exchange, Consumer postProcessor) { - this.exchange = exchange; + super(exchange); this.postProcessor = postProcessor; var msg = exchange.getResponse(); if (msg.isStream()) { - var parser = new SSEParser("response.completed","response.incompleted"); + log.debug("Streaming response."); + + var parser = new SSEParser(getTerminalEvents()); msg.getBody().addObserver(new AbstractMessageObserver() { @Override public void bodyChunk(Chunk chunk) { - if (!parser.parse(chunk)) { - return; - } - - var events = parser.getEvents(); - var terminal = parser.getTerminalEvent(); - - log.debug("---------------------------------------------------------------"); - log.debug("Events: {}", events.size()); - events.forEach(e -> log.debug("Event: {}", e)); - log.debug("---------------------------------------------------------------"); - - - terminal.ifPresent(event -> { - json = JsonUtil.getJsonObject(event.data()) - .orElse(JsonNodeFactory.instance.objectNode() - .put("error", "No JSON object response from model.")); - - postProcessor.accept(AbstractLLMResponse.this); - }); + processChunk(chunk, parser); } }); } else { json = JsonUtil.getJsonObject(exchange.getResponse()) .orElse(JsonNodeFactory.instance.objectNode().put("error", "No JSON object response from model.")); + postProcessor.accept(this); } } - @Override - public boolean isError() { - return json.get("error") != null && !json.get("error").isNull(); - } + protected void processChunk(Chunk chunk, SSEParser parser) { + // Wait for terminal chunk + if (!parser.parse(chunk)) { + return; + } - public Usage getUsage() { + // Now all chunks are parsed - var usage = json.path("usage"); + var events = parser.getEvents(); + var terminal = parser.getTerminalEvent(); - int inputTokens = getInputTokens(usage); - int outputTokens = getOutputTokens(usage); - int totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); + log.debug("Events: {}", events.size()); + events.forEach(this::process); - return new Usage( - inputTokens, - outputTokens, - totalTokens - ); + terminal.ifPresent(event -> { + // Terminal of old chat completion API + if ("[DONE]".equals(event.data())) + return; + json = JsonUtil.getJsonObject(event.data()) + .orElse(JsonNodeFactory.instance.objectNode() + .put("error", "No JSON object response from model.")); + + // All is read, call postProcessor + postProcessor.accept(AbstractLLMResponse.this); + }); + } + + @Override + public boolean isError() { + return json.get("error") != null && !json.get("error").isNull(); } protected static int getOutputTokens(JsonNode usage) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMRequest.java deleted file mode 100644 index d0bbb0be1a..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMRequest.java +++ /dev/null @@ -1,52 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMRequest; - -public class ClaudeLLMRequest extends AbstractLLMRequest { - - public static final String X_API_KEY = "x-api-key"; - - public ClaudeLLMRequest(Exchange exchange) { - super(exchange); - } - - public void setMaxOutputTokens(int maxOutputTokens) { - json.put("max_tokens", maxOutputTokens); - } - - @Override - public long estimateInputTokens() { - // System prompt - long tokens = json.path("system").asText().length() / 4; - - // Messages - for (var message : json.path("messages")) { - var content = message.path("content"); - if (content.isTextual()) { - tokens += content.asText().length() / 4; - } else if (content.isArray()) { - for (var block : content) { - var type = block.path("type").asText(); - if (type.equals("text")) { - tokens += block.path("text").asText().length() / 4; - } else if (type.equals("image")) { - tokens += 1000; - } - } - } - } - return tokens; - } - - @Override - public String getApiKey() { - return exchange.getRequest().getHeader().getFirstValue(X_API_KEY); - } - - @Override - public void setApiKey(String apiKey) { - exchange.getRequest().getHeader().removeFields(X_API_KEY); - exchange.getRequest().getHeader().add(X_API_KEY, apiKey); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java deleted file mode 100644 index 25d4726c66..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeLLMResponse.java +++ /dev/null @@ -1,15 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.LLMResponse; - -import java.util.function.Consumer; - -public class ClaudeLLMResponse extends AbstractLLMResponse { - - public ClaudeLLMResponse(Exchange exchange, Consumer postProcessor) { - super(exchange,postProcessor); - } - -} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java index 28cb31b5bc..102ab072d1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java @@ -1,8 +1,6 @@ package com.predic8.membrane.core.interceptor.ai.provider; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.LLMResponse; import java.util.function.Consumer; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java similarity index 71% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java index d401969e8c..9b0c577ca9 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java @@ -1,7 +1,9 @@ -package com.predic8.membrane.core.interceptor.ai; +package com.predic8.membrane.core.interceptor.ai.provider; import com.fasterxml.jackson.databind.node.ObjectNode; +import java.util.List; + public interface LLMRequest { String getModel(); @@ -16,4 +18,6 @@ public interface LLMRequest { ObjectNode getJson(); + List getTools(); + } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java new file mode 100644 index 0000000000..fd4979ca7e --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java @@ -0,0 +1,18 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser.SSEEvent; + +import java.util.Set; + +public interface LLMResponse { + + boolean isError(); + + Usage getUsage(); + + Set getTerminalEvents(); + + void process(SSEEvent event); + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java deleted file mode 100644 index 31544e59e6..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMResponse.java +++ /dev/null @@ -1,32 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.LLMResponse; -import com.predic8.membrane.core.interceptor.ai.store.Usage; - -import java.util.function.Consumer; - -public class OpenAiLLMResponse extends AbstractLLMResponse { - - public OpenAiLLMResponse(Exchange exchange, Consumer postProcessor) { - super(exchange,postProcessor); - } - - @Override - public Usage getUsage() { - - var usage = json.path("response").path("usage"); - - int inputTokens = getInputTokens(usage); - int outputTokens = getOutputTokens(usage); - int totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); - - return new Usage( - inputTokens, - outputTokens, - totalTokens - ); - } - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java new file mode 100644 index 0000000000..236f74a3ae --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java @@ -0,0 +1,89 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClaudeLLMRequest extends AbstractLLMRequest { + + private static final Logger log = LoggerFactory.getLogger(ClaudeLLMRequest.class); + + public static final String X_API_KEY = "x-api-key"; + + public ClaudeLLMRequest(Exchange exchange) { + super(exchange); + + exchange.getRequest().getHeader().setValue( "Accept-Encoding","identity"); + } + + public void setMaxOutputTokens(int maxOutputTokens) { + + // Thinking needs a certain number of tokens + if (maxOutputTokens < 2048 && isThinking()) { + log.info("maxOutputTokens is {}. Too low for thinking. Disabling thinking.", maxOutputTokens); + disableThinking(); + } + + json.put("max_tokens", maxOutputTokens); + + if (isThinking()) { + var thinking = (ObjectNode) json.path("thinking"); + if (!thinking.path("budget_tokens").isNull()) { + var budgetTokens = thinking.path("budget_tokens").asInt(); + if (budgetTokens >= maxOutputTokens) { + // budget_tokens must be smaller than max_tokens + // value might vary between models + thinking.put("budget_tokens", Math.min(maxOutputTokens / 2, 1024)); + } + } + + } + } + + @Override + public long estimateInputTokens() { + // System prompt + long tokens = json.path("system").asText().length() / 4; + + // Messages + for (var message : json.path("messages")) { + var content = message.path("content"); + if (content.isTextual()) { + tokens += content.asText().length() / 4; + } else if (content.isArray()) { + for (var block : content) { + var type = block.path("type").asText(); + if (type.equals("text")) { + tokens += block.path("text").asText().length() / 4; + } else if (type.equals("image")) { + tokens += 1000; + } + } + } + } + return tokens; + } + + private boolean isThinking() { + var thinking = json.path("thinking"); + return thinking.isObject() && "enabled".equals(thinking.path("type").asText()); + } + + private void disableThinking() { + var thinking = json.putObject("thinking"); + thinking.put("type", "disabled"); + } + + @Override + public String getApiKey() { + return exchange.getRequest().getHeader().getFirstValue(X_API_KEY); + } + + @Override + public void setApiKey(String apiKey) { + exchange.getRequest().getHeader().removeFields(X_API_KEY); + exchange.getRequest().getHeader().add(X_API_KEY, apiKey); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java new file mode 100644 index 0000000000..c46045cda5 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java @@ -0,0 +1,64 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser.SSEEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; +import java.util.function.Consumer; + +public class ClaudeLLMResponse extends AbstractLLMResponse { + + private static final Logger log = LoggerFactory.getLogger(ClaudeLLMResponse.class); + + private Usage usage; + + private StringBuffer inputJson = new StringBuffer(); + + private String tool; + + public ClaudeLLMResponse(Exchange exchange, Consumer postProcessor) { + super(exchange,postProcessor); + } + + @Override + public Set getTerminalEvents() { + return Set.of("message_stop"); + } + + @Override + public void process(SSEEvent event) { + log.debug("Event: {}", event); + + if ("content_block_start".equals(event.name())) { + var cbs = ContentBlockStart.from(event.json()); + if (cbs.getToolUse() != null) { + tool = cbs.getToolUse().getName(); + } + } + if ("message_delta".equals(event.name())) { + var md = MessageDelta.from(event.json()); + log.debug("Message delta: {}", md); + if (md.getUsage() != null) { + usage = md.getUsage(); + if (tool != null) + log.debug("Tool {} with {}", tool, inputJson.toString()); + } + } + if ("content_block_delta".equals(event.name())) { + var cbd = ContentBlockDelta.from(event.json()); + if (cbd.isInputJsonDelta()) { + inputJson.append(cbd.getPartialJson()); + } + } + } + + @Override + public Usage getUsage() { + return usage; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java similarity index 65% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeProvider.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java index d415c09881..471440f214 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/ClaudeProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java @@ -1,9 +1,10 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +package com.predic8.membrane.core.interceptor.ai.provider.claude; import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; import java.util.function.Consumer; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockDelta.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockDelta.java new file mode 100644 index 0000000000..f8d32c1c97 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockDelta.java @@ -0,0 +1,39 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; + +public class ContentBlockDelta { + + private int index; + private String deltaType; + private String partialJson; + + public static ContentBlockDelta from(ObjectNode on) { + var cbd = new ContentBlockDelta(); + + cbd.index = on.path("index").asInt(); + + JsonNode delta = on.path("delta"); + cbd.deltaType = delta.path("type").asText(null); + cbd.partialJson = delta.path("partial_json").asText(""); + + return cbd; + } + + public boolean isInputJsonDelta() { + return "input_json_delta".equals(deltaType); + } + + public int getIndex() { + return index; + } + + public String getDeltaType() { + return deltaType; + } + + public String getPartialJson() { + return partialJson; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java new file mode 100644 index 0000000000..b98f1d5827 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java @@ -0,0 +1,23 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.fasterxml.jackson.databind.node.ObjectNode; + +public class ContentBlockStart { + + private ToolUse toolUse; + + public static ContentBlockStart from(ObjectNode on) { + var cbs = new ContentBlockStart(); + var cb = (ObjectNode) on.path("content_block"); + + if ("tool_use".equals(cb.path("type").asText())) { + cbs.toolUse = ToolUse.from(cb); + } + + return cbs; + } + + public ToolUse getToolUse() { + return toolUse; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java new file mode 100644 index 0000000000..c2f3106a53 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java @@ -0,0 +1,70 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.interceptor.ai.store.Usage; + +public class MessageDelta { + + private String stopReason; + private int inputTokens; + private int outputTokens; + private int cacheCreationInputTokens; + private int cacheReadInputTokens; + + private Usage usage; + + public static MessageDelta from(ObjectNode on) { + var md = new MessageDelta(); + + JsonNode delta = on.path("delta"); + md.stopReason = delta.path("stop_reason").asText(null); + + JsonNode u = on.path("usage"); + if (u.isObject()) { + md.inputTokens = u.path("input_tokens").asInt(0); + md.outputTokens = u.path("output_tokens").asInt(0); + md.cacheCreationInputTokens = u.path("cache_creation_input_tokens").asInt(0); + md.cacheReadInputTokens = u.path("cache_read_input_tokens").asInt(0); + + md.usage = new Usage(md.inputTokens, md.outputTokens, md.inputTokens + md.outputTokens); + } + + return md; + } + + public String getStopReason() { + return stopReason; + } + + public int getInputTokens() { + return inputTokens; + } + + public int getOutputTokens() { + return outputTokens; + } + + public int getCacheCreationInputTokens() { + return cacheCreationInputTokens; + } + + public int getCacheReadInputTokens() { + return cacheReadInputTokens; + } + + public Usage getUsage() { + return usage; + } + + @Override + public String toString() { + return "MessageDelta{" + + "stopReason='" + stopReason + '\'' + + ", inputTokens=" + inputTokens + + ", outputTokens=" + outputTokens + + ", cacheCreationInputTokens=" + cacheCreationInputTokens + + ", cacheReadInputTokens=" + cacheReadInputTokens + + '}'; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java new file mode 100644 index 0000000000..59ef545e68 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java @@ -0,0 +1,22 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ToolUse { + + private static final Logger log = LoggerFactory.getLogger(ToolUse.class); + + private String name; + + public static ToolUse from(ObjectNode on) { + var tu = new ToolUse(); + tu.name = on.path("name").asText(); + return tu; + } + + public String getName() { + return name; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java similarity index 95% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java index 693827ebcd..19ce54f6b2 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java @@ -1,9 +1,9 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +package com.predic8.membrane.core.interceptor.ai.provider.google; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMRequest; public class GoogleLLMRequest extends AbstractLLMRequest { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java similarity index 61% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java index 82fec4b0ba..5b0adb2144 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java @@ -1,10 +1,12 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +package com.predic8.membrane.core.interceptor.ai.provider.google; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser; +import java.util.Set; import java.util.function.Consumer; public class GoogleLLMResponse extends AbstractLLMResponse { @@ -27,4 +29,14 @@ public Usage getUsage() { totalTokens ); } + + @Override + public Set getTerminalEvents() { + return Set.of("response.completed","response.incompleted"); + } + + @Override + public void process(SSEParser.SSEEvent event) { + + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java similarity index 66% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleProvider.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java index 0cf43df16d..9a7d27e31d 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/GoogleProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java @@ -1,9 +1,10 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +package com.predic8.membrane.core.interceptor.ai.provider.google; import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; import java.util.function.Consumer; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java similarity index 65% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAIProvider.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java index 92dd80b3e2..d92e0ba254 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAIProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java @@ -1,9 +1,10 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +package com.predic8.membrane.core.interceptor.ai.provider.openai; import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; import java.util.function.Consumer; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java similarity index 87% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java index 940b248a6c..eafe6cd0f7 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/OpenAiLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java @@ -1,14 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +package com.predic8.membrane.core.interceptor.ai.provider.openai; import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMRequest; public class OpenAiLLMRequest extends AbstractLLMRequest { public OpenAiLLMRequest(Exchange exchange) { super(exchange); + if (json == null) { + return; + } + // Make sure that when streaming is enabled, the usage is included in the response. if (json.path("stream").asBoolean(false)) { if (isChatCompletionsRequest(exchange)) { @@ -39,7 +43,12 @@ public long estimateInputTokens() { public void setMaxOutputTokens(int maxOutputTokens) { // OpenAI deprecated max_tokens for newer models (o1, o3, gpt-5.x) in // favor of max_completion_tokens. Older models still accept max_tokens. - json.put("max_output_tokens", maxOutputTokens); + if (api == API.NORMAL) { + json.put("max_output_tokens", maxOutputTokens); + } + if (api == API.COMPLETIONS) { + json.put("max_completion_tokens", maxOutputTokens); + } } private long estimateChatCompletitions() { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java new file mode 100644 index 0000000000..76e4c65c9b --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java @@ -0,0 +1,69 @@ +package com.predic8.membrane.core.interceptor.ai.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMEvent; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; +import java.util.function.Consumer; + +public class OpenAiLLMResponse extends AbstractLLMResponse { + + private static final Logger log = LoggerFactory.getLogger(OpenAiLLMResponse.class); + + public OpenAiLLMResponse(Exchange exchange, Consumer postProcessor) { + super(exchange,postProcessor); + } + + @Override + public Usage getUsage() { + + int inputTokens = 0; + int outputTokens = 0; + int totalTokens = 0; + + // Responses API + if (!json.path("response").isMissingNode()) { + var usage = json.path("response").path("usage"); + + getInputTokens(usage); + getOutputTokens(usage); + usage.path("total_tokens").asInt(inputTokens + outputTokens); + } else { + // Older chat completions API + inputTokens = json.path("usage").path("prompt_tokens").asInt(0); + outputTokens = json.path("usage").path("completion_tokens").asInt(0); + totalTokens = json.path("total_tokens").asInt(inputTokens + outputTokens); + + } + + return new Usage( + inputTokens, + outputTokens, + totalTokens + ); + } + + @Override + public Set getTerminalEvents() { + return Set.of("response.completed","response.incompleted"); + } + + @Override + public void process(SSEParser.SSEEvent e) { + log.debug("Event: {}", e.name()); + log.debug("Data: {}", e.data()); + var event = AbstractLLMEvent.create(e); + System.out.println(event); + + var json = event.getJson(); + if (!json.path("usage").isNull()) { + + } + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java b/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java index 1fe9f7fb7c..acddbc7428 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java +++ b/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java @@ -1,6 +1,11 @@ package com.predic8.membrane.core.util.http; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.http.Chunk; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; @@ -9,6 +14,8 @@ public final class SSEParser { + private static final Logger log = LoggerFactory.getLogger(SSEParser.class); + private final Set terminalEventNames; private final StringBuilder buffer = new StringBuilder(); @@ -19,8 +26,8 @@ public final class SSEParser { private boolean terminalFound; - public SSEParser(String... terminalEventNames) { - this.terminalEventNames = Set.of(terminalEventNames); + public SSEParser(Set terminalEventNames) { + this.terminalEventNames = terminalEventNames; } public boolean parse(Chunk chunk) { @@ -28,6 +35,8 @@ public boolean parse(Chunk chunk) { return true; } + log.debug("Parsing SSE chunk: {}", chunk); + buffer.append(chunk.toString()); int lineEnd; @@ -41,7 +50,7 @@ public boolean parse(Chunk chunk) { if (event != null) { events.add(event); - if (terminalEventNames.contains(event.name())) { + if ((event.name() != null && terminalEventNames.contains(event.name())) || "[DONE]".equals(event.data())) { terminalFound = true; return true; } @@ -136,5 +145,18 @@ private static String readLine(StringBuilder buffer, int lineEnd) { return line; } - public record SSEEvent(String name, String data) {} + public record SSEEvent(String name, String data) { + + private static final ObjectMapper om = new ObjectMapper(); + + public ObjectNode json() { + try { + return (ObjectNode) om.readTree(data); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + } + } \ No newline at end of file diff --git a/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java b/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java index 2421770991..c08ecd3a09 100644 --- a/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java +++ b/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java @@ -3,13 +3,15 @@ import com.predic8.membrane.core.http.Chunk; import org.junit.jupiter.api.Test; +import java.util.Set; + import static org.junit.jupiter.api.Assertions.*; class SSEParserTest { @Test void parsesSingleEvent() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); assertFalse(parser.parse(chunk(""" event: message @@ -27,7 +29,7 @@ void parsesSingleEvent() { @Test void parsesMultilineData() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); parser.parse(chunk(""" event: message @@ -41,7 +43,7 @@ void parsesMultilineData() { @Test void parsesEventSplitAcrossChunks() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); assertFalse(parser.parse(chunk(""" event: mes"""))); @@ -63,7 +65,7 @@ void parsesEventSplitAcrossChunks() { @Test void returnsTrueWhenTerminalEventIsFound() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); assertTrue(parser.parse(chunk(""" event: done @@ -80,7 +82,7 @@ void returnsTrueWhenTerminalEventIsFound() { @Test void ignoresChunksAfterTerminalEvent() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); assertTrue(parser.parse(chunk(""" event: done @@ -100,7 +102,7 @@ void ignoresChunksAfterTerminalEvent() { @Test void ignoresCommentsAndUnknownFields() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); parser.parse(chunk(""" : comment @@ -119,7 +121,7 @@ void ignoresCommentsAndUnknownFields() { @Test void supportsCrLfLineEndings() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); parser.parse(chunk("event: message\r\ndata: hello\r\n\r\n")); @@ -131,7 +133,7 @@ void supportsCrLfLineEndings() { @Test void returnsUnmodifiableEventsList() { - var parser = new SSEParser("done"); + var parser = new SSEParser(Set.of("done")); parser.parse(chunk(""" event: message From b124aa06da9e93fe6ca71ff4a455275f950af3cb Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 18 May 2026 13:48:16 +0200 Subject: [PATCH 14/43] refactor: remove `ChatCompletionsSSEParser` and unused token usage tracking logic - Deleted `ChatCompletionsSSEParser` and related classes/methods. - Simplified `ChatCompletionEvent` by removing token usage parsing. - Updated tool extraction logic in `AbstractLLMRequest` to handle function-specific tools. --- .../interceptor/ai/ChatCompletionEvent.java | 7 - .../ai/ChatCompletionsSSEParser.java | 203 ------------------ .../ai/provider/AbstractLLMRequest.java | 8 +- .../ai/provider/claude/ClaudeLLMResponse.java | 2 +- 4 files changed, 8 insertions(+), 212 deletions(-) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java index 7782c93f45..86531144d9 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java @@ -13,13 +13,6 @@ public ChatCompletionEvent(JsonNode json) { parseChoices(json); - var usage = json.path("usage"); - if (!usage.isNull()) { - var inputTokens = usage.get("prompt_tokens").asInt(); - var outputTokens = usage.get("completion_tokens").asInt(); - var totalTokens = usage.get("total_tokens").asInt(); - System.out.println("------------------------------totalTokens = " + totalTokens); - } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java deleted file mode 100644 index 7ac95f7456..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionsSSEParser.java +++ /dev/null @@ -1,203 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.predic8.membrane.core.http.Chunk; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; - -public final class ChatCompletionsSSEParser { - - private static final Logger log = LoggerFactory.getLogger(ChatCompletionsSSEParser.class); - - private final StringBuilder buffer = new StringBuilder(); - private final List chunks = new ArrayList<>(); - private final StringBuilder data = new StringBuilder(); - - private boolean done; - - public boolean parse(Chunk chunk) { - if (done) - return true; - - log.debug("Parsing chat completions SSE chunk: {}", chunk); - - buffer.append(chunk.toString()); - - int lineEnd; - while ((lineEnd = findLineEnd(buffer)) >= 0) { - String line = readLine(buffer, lineEnd); - - if (line.isEmpty()) { - ChatCompletionChunk parsedChunk = buildChunk(); - resetEvent(); - - if (parsedChunk != null) { - if (parsedChunk.isDone()) { - done = true; - return true; - } - - chunks.add(parsedChunk); - } - - continue; - } - - parseLine(line); - } - - return false; - } - - public List getChunks() { - return List.copyOf(chunks); - } - - public Optional getLastChunk() { - if (chunks.isEmpty()) - return Optional.empty(); - - return Optional.of(chunks.get(chunks.size() - 1)); - } - - public boolean isDone() { - return done; - } - - private ChatCompletionChunk buildChunk() { - if (data.isEmpty()) - return null; - - String value = data.toString(); - - if ("[DONE]".equals(value)) - return ChatCompletionChunk.done(); - - return ChatCompletionChunk.json(value); - } - - private void resetEvent() { - data.setLength(0); - } - - private void parseLine(String line) { - if (line.startsWith(":")) - return; - - int colon = line.indexOf(':'); - - String field = colon >= 0 ? line.substring(0, colon) : line; - String value = colon >= 0 ? line.substring(colon + 1) : ""; - - if (value.startsWith(" ")) - value = value.substring(1); - - if ("data".equals(field)) { - if (!data.isEmpty()) - data.append('\n'); - - data.append(value); - } - } - - private static int findLineEnd(StringBuilder buffer) { - for (int i = 0; i < buffer.length(); i++) { - char c = buffer.charAt(i); - - if (c == '\n' || c == '\r') - return i; - } - - return -1; - } - - private static String readLine(StringBuilder buffer, int lineEnd) { - String line = buffer.substring(0, lineEnd); - - int removeUntil = lineEnd + 1; - - if (lineEnd + 1 < buffer.length() - && buffer.charAt(lineEnd) == '\r' - && buffer.charAt(lineEnd + 1) == '\n') { - removeUntil++; - } - - buffer.delete(0, removeUntil); - return line; - } - - public static final class ChatCompletionChunk { - - private static final ObjectMapper om = new ObjectMapper(); - - private final boolean done; - private final String data; - private ObjectNode json; - - private ChatCompletionChunk(boolean done, String data) { - this.done = done; - this.data = data; - } - - public static ChatCompletionChunk done() { - return new ChatCompletionChunk(true, null); - } - - public static ChatCompletionChunk json(String data) { - return new ChatCompletionChunk(false, data); - } - - public boolean isDone() { - return done; - } - - public String getData() { - return data; - } - - public ObjectNode json() { - if (done) - throw new IllegalStateException("[DONE] has no JSON body."); - - if (json != null) - return json; - - try { - json = (ObjectNode) om.readTree(data); - return json; - } catch (JsonProcessingException e) { - throw new RuntimeException("Could not parse chat completion chunk JSON.", e); - } - } - - public String contentDelta() { - if (done) - return null; - - return json() - .path("choices") - .path(0) - .path("delta") - .path("content") - .asText(null); - } - - public boolean hasToolCalls() { - if (done) - return false; - - return json() - .path("choices") - .path(0) - .path("delta") - .path("tool_calls") - .isArray(); - } - } -} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java index 01d01003df..a28b797d05 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java @@ -32,7 +32,13 @@ public List getTools() { var tools = getToolsNode(); if (tools == null) return emptyList(); - return tools.valueStream().map(n -> n.get("name").asText()).toList(); + return tools.valueStream().map(n -> { + // Chat completion + if (n.has("function")) { + return n.get("function").get("name").asText(); + } + return n.get("name").asText(); + }).toList(); } private ArrayNode getToolsNode() { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java index c46045cda5..0e44dc8c4b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java @@ -17,7 +17,7 @@ public class ClaudeLLMResponse extends AbstractLLMResponse { private Usage usage; - private StringBuffer inputJson = new StringBuffer(); + private final StringBuffer inputJson = new StringBuffer(); private String tool; From 46ce127a787c72d4a4737684b85fd7280958ad92 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 18 May 2026 15:29:02 +0200 Subject: [PATCH 15/43] chore: add TODO for handling client-provided API key if config key is absent --- .../membrane/core/interceptor/ai/LLMGatewayInterceptor.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index fa5f6748b3..5af30bc600 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -83,6 +83,7 @@ public Outcome handleRequest(Exchange exc) { exc.setProperty(MEMBRANE_AI_USER, user); } + // TODO if no apiKey in config => use key from client aiReq.setApiKey(apiKey); if (maxOutputTokens != 0) { From 5d7be1e571e0c77ccc00eb8eea1171f351013f24 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 19 May 2026 11:22:58 +0200 Subject: [PATCH 16/43] feat: enhance LLM response handling with modular classes and improved usage tracking - Added `OpenAiLLMResponsesAPIResponse` for handling OpenAI Responses API. - Refactored `OpenAiLLMResponse` to `OpenAiChatCompletionsLLMResponse`. - Improved token usage calculations and SSE event processing. - Updated `OpenAIProvider` to differentiate between Responses API and Chat Completions. --- .../interceptor/ai/LLMGatewayInterceptor.java | 5 +- .../interceptor/ai/ResponsesApiEvent.java | 4 +- .../ai/provider/AbstractLLMRequest.java | 7 ++ .../ai/provider/AbstractLLMResponse.java | 4 +- .../ai/provider/claude/MessageDelta.java | 5 +- .../ai/provider/openai/OpenAIProvider.java | 11 ++- .../OpenAiChatCompletionsLLMResponse.java | 50 ++++++++++++++ .../ai/provider/openai/OpenAiLLMRequest.java | 3 +- .../ai/provider/openai/OpenAiLLMResponse.java | 69 ------------------- .../openai/OpenAiLLMResponsesAPIResponse.java | 52 ++++++++++++++ 10 files changed, 133 insertions(+), 77 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 5af30bc600..820018908c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -74,8 +74,11 @@ public Outcome handleRequest(Exchange exc) { log.debug("User: {}", user); if (exc.getRequest().isPOSTRequest()) { inputTokens = aiReq.estimateInputTokens(); + log.debug("Estimated input tokens: {}", inputTokens); var remaining = store.checkLimit(user, inputTokens, maxOutputTokens); + log.debug("Remaining tokens: {}", remaining); if (remaining <= 0) { + log.info("Token limit exceeded: {}/{}", inputTokens, maxOutputTokens); exc.setResponse(tokenLimitExceeded()); return RETURN; } @@ -143,7 +146,7 @@ public void setAiStore(AiApiStore store) { @Override public String getDisplayName() { - return "OpenAI API"; + return "LLM Gateway"; } public int getMaxOutputTokens() { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java index ab2c6fab1a..4e54f26b64 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java @@ -26,9 +26,9 @@ public ResponsesApiEvent(JsonNode json) { ObjectNode on = (ObjectNode) item; if ("function_call".equals(on.path("type").asText())) { - log.info("Function call: {} with {}", + log.info("Function call: {} with {} params", on.path("name").asText(), - on.path("arguments").asText()); + on.path("arguments").size()); } } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java index a28b797d05..98418f8e5a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java @@ -33,6 +33,13 @@ public List getTools() { if (tools == null) return emptyList(); return tools.valueStream().map(n -> { + String type; + if (n.has("type")) { + type = n.get("type").asText(); + if (!"function".equals(type)) + return null; + } + // Chat completion if (n.has("function")) { return n.get("function").get("name").asText(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java index d1550af0d2..de54298fd1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java @@ -60,8 +60,10 @@ protected void processChunk(Chunk chunk, SSEParser parser) { terminal.ifPresent(event -> { // Terminal of old chat completion API - if ("[DONE]".equals(event.data())) + if ("[DONE]".equals(event.data())) { + postProcessor.accept(AbstractLLMResponse.this); return; + } json = JsonUtil.getJsonObject(event.data()) .orElse(JsonNodeFactory.instance.objectNode() .put("error", "No JSON object response from model.")); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java index c2f3106a53..a99b04e2bb 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java @@ -27,7 +27,10 @@ public static MessageDelta from(ObjectNode on) { md.cacheCreationInputTokens = u.path("cache_creation_input_tokens").asInt(0); md.cacheReadInputTokens = u.path("cache_read_input_tokens").asInt(0); - md.usage = new Usage(md.inputTokens, md.outputTokens, md.inputTokens + md.outputTokens); + // Cache tokens (cache_creation_input_tokens and cache_read_input_tokens) are billable according to Claude's pricing model + int effectiveInputTokens = md.inputTokens + md.cacheCreationInputTokens + md.cacheReadInputTokens; + md.usage = new Usage(effectiveInputTokens,md.outputTokens, effectiveInputTokens + md.outputTokens); + } return md; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java index d92e0ba254..43c8ce9223 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java @@ -17,7 +17,14 @@ public LLMRequest getLLMRequest(Exchange exchange) { } @Override - public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { - return new OpenAiLLMResponse(request, postProcessor); + public LLMResponse getLLMResponse(Exchange exchange, Consumer postProcessor) { + if (isResponsesApi(exchange)) { + return new OpenAiLLMResponsesAPIResponse(exchange,postProcessor); + } + return new OpenAiChatCompletionsLLMResponse(exchange, postProcessor); + } + + static boolean isResponsesApi(Exchange exchange) { + return exchange.getRequest().getUri().startsWith("/v1/responses"); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java new file mode 100644 index 0000000000..112e8df411 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java @@ -0,0 +1,50 @@ +package com.predic8.membrane.core.interceptor.ai.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMEvent; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; +import java.util.function.Consumer; + +public class OpenAiChatCompletionsLLMResponse extends AbstractLLMResponse { + + private static final Logger log = LoggerFactory.getLogger(OpenAiChatCompletionsLLMResponse.class); + + public OpenAiChatCompletionsLLMResponse(Exchange exchange, Consumer postProcessor) { + super(exchange, postProcessor); + } + + @Override + public Usage getUsage() { + + var usage = json.path("usage"); + + var inputTokens = usage.path("prompt_tokens").asInt(0); + var outputTokens = usage.path("completion_tokens").asInt(0); + var totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); + + return new Usage( + inputTokens, + outputTokens, + totalTokens + ); + } + + @Override + public Set getTerminalEvents() { + return Set.of("[DONE]"); + } + + @Override + public void process(SSEParser.SSEEvent e) { + log.debug("Data: {}", e.data()); + var event = AbstractLLMEvent.create(e); + System.out.println(event); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java index eafe6cd0f7..784de9a983 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java @@ -16,7 +16,8 @@ public OpenAiLLMRequest(Exchange exchange) { // Make sure that when streaming is enabled, the usage is included in the response. if (json.path("stream").asBoolean(false)) { if (isChatCompletionsRequest(exchange)) { - json.putObject("stream_options").put("include_usage", true); + var streamOptions = json.withObject("/stream_options"); + streamOptions.put("include_usage", true); } } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java deleted file mode 100644 index 76e4c65c9b..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponse.java +++ /dev/null @@ -1,69 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMEvent; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; -import com.predic8.membrane.core.interceptor.ai.store.Usage; -import com.predic8.membrane.core.util.http.SSEParser; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Set; -import java.util.function.Consumer; - -public class OpenAiLLMResponse extends AbstractLLMResponse { - - private static final Logger log = LoggerFactory.getLogger(OpenAiLLMResponse.class); - - public OpenAiLLMResponse(Exchange exchange, Consumer postProcessor) { - super(exchange,postProcessor); - } - - @Override - public Usage getUsage() { - - int inputTokens = 0; - int outputTokens = 0; - int totalTokens = 0; - - // Responses API - if (!json.path("response").isMissingNode()) { - var usage = json.path("response").path("usage"); - - getInputTokens(usage); - getOutputTokens(usage); - usage.path("total_tokens").asInt(inputTokens + outputTokens); - } else { - // Older chat completions API - inputTokens = json.path("usage").path("prompt_tokens").asInt(0); - outputTokens = json.path("usage").path("completion_tokens").asInt(0); - totalTokens = json.path("total_tokens").asInt(inputTokens + outputTokens); - - } - - return new Usage( - inputTokens, - outputTokens, - totalTokens - ); - } - - @Override - public Set getTerminalEvents() { - return Set.of("response.completed","response.incompleted"); - } - - @Override - public void process(SSEParser.SSEEvent e) { - log.debug("Event: {}", e.name()); - log.debug("Data: {}", e.data()); - var event = AbstractLLMEvent.create(e); - System.out.println(event); - - var json = event.getJson(); - if (!json.path("usage").isNull()) { - - } - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java new file mode 100644 index 0000000000..9825fa712d --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java @@ -0,0 +1,52 @@ +package com.predic8.membrane.core.interceptor.ai.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.AbstractLLMEvent; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; +import java.util.function.Consumer; + +public class OpenAiLLMResponsesAPIResponse extends AbstractLLMResponse { + + private static final Logger log = LoggerFactory.getLogger(OpenAiLLMResponsesAPIResponse.class); + + public OpenAiLLMResponsesAPIResponse(Exchange exchange, Consumer postProcessor) { + super(exchange, postProcessor); + } + + @Override + public Usage getUsage() { + + var usage = json.path("usage"); + + // For streamed response.completed events + if (usage.isMissingNode() || usage.isNull()) { + usage = json.path("response").path("usage"); + } + + var inputTokens = getInputTokens(usage); + var outputTokens = getOutputTokens(usage); + var totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); + return new Usage(inputTokens, outputTokens, totalTokens); + + } + + @Override + public Set getTerminalEvents() { + return Set.of("response.completed", "response.incomplete"); + } + + @Override + public void process(SSEParser.SSEEvent e) { + log.debug("Event: {}", e.name()); + log.debug("Data: {}", e.data()); + var event = AbstractLLMEvent.create(e); + System.out.println(event); + } +} From 2bfca01d8513bccfd2ddc157cf3c6e8a74cbbbcc Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 19 May 2026 11:25:34 +0200 Subject: [PATCH 17/43] feat: add logging for non-JSON requests in `AbstractLLMRequest` - Introduced SLF4J logger to `AbstractLLMRequest`. - Added log message for handling non-JSON requests. - Improved exception handling with informative runtime error. --- .../interceptor/ai/provider/AbstractLLMRequest.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java index 98418f8e5a..b502f9e79f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java @@ -4,6 +4,8 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.util.json.JsonUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.List; @@ -12,6 +14,8 @@ public abstract class AbstractLLMRequest extends AbstractLLMMessage implements LLMRequest { + private static final Logger log = LoggerFactory.getLogger(AbstractLLMRequest.class); + public static final String BEARER_PREFIX = "Bearer"; protected ObjectNode json; @@ -21,10 +25,9 @@ public AbstractLLMRequest(Exchange exchange) { if (exchange.getRequest().isJSON()) { json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("No JSON object request.")); - - if (json.path("tools").isArray()) { - //System.out.println("Tools: " + json.path("tools")); - } + } else { + log.info("Request is not JSON:"); + throw new RuntimeException("Request is not JSON."); } } From 3e9a01439292306644085b26e11e7c78adb5ee81 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 19 May 2026 11:34:24 +0200 Subject: [PATCH 18/43] refactor: replace `System.out.println` with proper debug logging in OpenAI response handlers --- .../ai/provider/openai/OpenAiChatCompletionsLLMResponse.java | 2 +- .../ai/provider/openai/OpenAiLLMResponsesAPIResponse.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java index 112e8df411..cb226bdf28 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java @@ -45,6 +45,6 @@ public Set getTerminalEvents() { public void process(SSEParser.SSEEvent e) { log.debug("Data: {}", e.data()); var event = AbstractLLMEvent.create(e); - System.out.println(event); + log.debug("Event: {}", event); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java index 9825fa712d..4cff5b313f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java @@ -47,6 +47,6 @@ public void process(SSEParser.SSEEvent e) { log.debug("Event: {}", e.name()); log.debug("Data: {}", e.data()); var event = AbstractLLMEvent.create(e); - System.out.println(event); + log.debug("Event: {}", event); } } From 68705e1ee1a68a288d632432335ce1778c5f9130 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 10:22:15 +0200 Subject: [PATCH 19/43] feat: introduce Basic LLM Gateway tutorial and enhance OpenAI LLM handling - Added `10-Basic-LLM-Gateway.yaml` tutorial for setting up a basic LLM gateway. - Introduced new classes `AbstractOpenAiLLMRequest` and `OpenAiLLMChatCompletionsRequest` for modularizing token estimation and API handling. - Improved token usage tracking with client-requested max output tokens. - Added detailed inline documentation across AI-related classes for better maintainability. - Updated `membrane.cmd` and `membrane.sh` for enhanced gateway setup. --- .../interceptor/ai/LLMGatewayInterceptor.java | 92 ++++++++++++++++--- .../interceptor/ai/ResponsesApiEvent.java | 21 ++++- .../ai/provider/AbstractLLMRequest.java | 24 +---- .../ai/provider/AbstractLLMResponse.java | 13 +-- .../interceptor/ai/provider/LLMRequest.java | 6 ++ .../ai/provider/claude/ClaudeLLMRequest.java | 5 + .../ai/provider/claude/ClaudeLLMResponse.java | 1 + .../ai/provider/claude/ClaudeProvider.java | 4 + .../ai/provider/google/GoogleLLMRequest.java | 7 ++ .../ai/provider/google/GoogleProvider.java | 4 + ...est.java => AbstractOpenAiLLMRequest.java} | 32 +------ .../ai/provider/openai/OpenAIProvider.java | 17 +++- .../OpenAiChatCompletionsLLMResponse.java | 5 + .../OpenAiLLMChatCompletionsRequest.java | 51 ++++++++++ .../openai/OpenAiLLMResponsesRequest.java | 41 +++++++++ ...e.java => OpenAiLLMResponsesResponse.java} | 15 ++- .../core/interceptor/ai/store/AiApiLimit.java | 13 +++ .../core/interceptor/ai/store/AiApiUser.java | 4 + .../ai/store/JDBCAiApiUsageStore.java | 6 +- .../interceptor/ai/store/NoAiApiLimit.java | 3 + .../ai/store/SimpleAiApiStore.java | 26 +++++- .../membrane/core/util/json/JsonUtil.java | 4 +- .../ai/llm-gateway/10-Basic-LLM-Gateway.yaml | 28 ++++++ .../tutorials/ai/llm-gateway/max-input.json | 4 + .../tutorials/ai/llm-gateway/max-output.json | 5 + .../tutorials/ai/llm-gateway/membrane.cmd | 24 +++++ .../tutorials/ai/llm-gateway/membrane.sh | 21 +++++ .../tutorials/ai/llm-gateway/simple.json | 4 + 28 files changed, 390 insertions(+), 90 deletions(-) rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/{OpenAiLLMRequest.java => AbstractOpenAiLLMRequest.java} (67%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java rename core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/{OpenAiLLMResponsesAPIResponse.java => OpenAiLLMResponsesResponse.java} (72%) create mode 100644 distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml create mode 100644 distribution/tutorials/ai/llm-gateway/max-input.json create mode 100644 distribution/tutorials/ai/llm-gateway/max-output.json create mode 100644 distribution/tutorials/ai/llm-gateway/membrane.cmd create mode 100755 distribution/tutorials/ai/llm-gateway/membrane.sh create mode 100644 distribution/tutorials/ai/llm-gateway/simple.json diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 820018908c..56ef011243 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -21,7 +21,20 @@ import static com.predic8.membrane.core.interceptor.ai.LLMApiUtil.*; import static com.predic8.membrane.core.util.json.JsonUtil.setJsonBody; -@MCElement(name = "aiGateway") +/* + * @description

+ * API Gateway for Large Language Models (LLMs). + *

+ * Features: + *
    + *
  • Sharing an API key between multiple users
  • + *
  • Enforcing token limits
  • + *
  • Logging LLM usage
  • + *
+ *

+ * @topic 10. AI + */ +@MCElement(name = "llmGateway") public class LLMGatewayInterceptor extends AbstractInterceptor { private static final Logger log = LoggerFactory.getLogger(LLMGatewayInterceptor.class); @@ -50,7 +63,7 @@ public Outcome handleRequest(Exchange exc) { try { aiReq = provider.getLLMRequest(exc); } catch (Exception e) { - user(router.getConfiguration().isProduction(),"AI Gateway") + user(router.getConfiguration().isProduction(), "AI Gateway") .title("Invalid request") .detail("Error parsing request: " + e.getMessage()) .buildAndSetResponse(exc); @@ -62,19 +75,22 @@ public Outcome handleRequest(Exchange exc) { return CONTINUE; } - long inputTokens = 0; - + AiApiUser user = null; if (store != null) { var opt = store.getUser(aiReq.getApiKey()); if (opt.isEmpty()) { exc.setResponse(authenticationFailed()); return RETURN; } - var user = opt.get(); + user = opt.get(); log.debug("User: {}", user); - if (exc.getRequest().isPOSTRequest()) { - inputTokens = aiReq.estimateInputTokens(); - log.debug("Estimated input tokens: {}", inputTokens); + } + + long inputTokens = 0; + if (exc.getRequest().isPOSTRequest()) { + inputTokens = aiReq.estimateInputTokens(); + log.debug("Estimated input tokens: {}", inputTokens); + if (store != null) { var remaining = store.checkLimit(user, inputTokens, maxOutputTokens); log.debug("Remaining tokens: {}", remaining); if (remaining <= 0) { @@ -83,18 +99,27 @@ public Outcome handleRequest(Exchange exc) { return RETURN; } } - exc.setProperty(MEMBRANE_AI_USER, user); } + exc.setProperty(MEMBRANE_AI_USER, user); - // TODO if no apiKey in config => use key from client - aiReq.setApiKey(apiKey); - if (maxOutputTokens != 0) { + // If APIKey is specified, use that for the LLM. Overwrites keys from the client + if (apiKey != null) { + aiReq.setApiKey(apiKey); + } + + log.debug("max-tokens from client: {}", aiReq.getModel()); + + var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); + + if (maxOutputTokens != 0 && requestedMaxOutputTokens > maxOutputTokens) { + log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.",requestedMaxOutputTokens, maxOutputTokens); aiReq.setMaxOutputTokens(maxOutputTokens); } if (maxInputTokens != 0) { if (inputTokens > maxInputTokens) { + log.info("Input tokens {} exceed the limit of {}.",inputTokens, maxInputTokens); exc.setResponse(contextLengthExceeded(maxInputTokens, inputTokens)); return RETURN; } @@ -108,7 +133,7 @@ public Outcome handleRequest(Exchange exc) { } } - log.debug("Tools: {}", aiReq.getTools()); + log.debug("Agent provides the tools: {}", aiReq.getTools()); setJsonBody(exc.getRequest(), aiReq.getJson()); return CONTINUE; @@ -117,9 +142,15 @@ public Outcome handleRequest(Exchange exc) { @Override public Outcome handleResponse(Exchange exc) { - var aiRes = provider.getLLMResponse(exc, res -> { + provider.getLLMResponse(exc, res -> { + var user = exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class); + if (log.isInfoEnabled() && user != null) { + log.debug("Token usage of user {}: {}", user, res.getUsage()); + } else { + log.info("Token usage: {}", res.getUsage()); + } if (store != null) { - store.store(exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class), res.getUsage()); + store.store(user, res.getUsage()); } }); @@ -130,6 +161,10 @@ public String getApiKey() { return apiKey; } + /** + * @param apiKey + * @description API key for the LLM provider. Specify here the API key from OpenAI or Anthropic. + */ @MCAttribute public void setApiKey(String apiKey) { this.apiKey = apiKey; @@ -139,6 +174,12 @@ public AiApiStore getAiStore() { return store; } + /** + * @param store Store for API keys and usage statistics + * @description The LLM Gateway can operate stateless and statefully. For stateful operation, specify an AiApiStore. + * A store is needed for user authentication at the gateway. + * The gateway will use the store to enforce token limits and log usage statistics. + */ @MCChildElement(allowForeign = true, order = 10) public void setAiStore(AiApiStore store) { this.store = store; @@ -153,6 +194,12 @@ public int getMaxOutputTokens() { return maxOutputTokens; } + /** + * @param maxOutputTokens + * @description Maximum number of tokens the LLM should use to generate a response. This is just a hint that the gateway + * sends to the LLM provider. The provider may use a different limit. + * @default 0 (unlimited) + */ @MCAttribute public void setMaxOutputTokens(int maxOutputTokens) { this.maxOutputTokens = maxOutputTokens; @@ -162,6 +209,11 @@ public int getMaxInputTokens() { return maxInputTokens; } + /** + * @param maxInputTokens + * @description Restricts token usage for the input. The size of the input is estimated by gateway based on the request size. + * Actual token usage may be deviate from this value. + */ @MCAttribute public void setMaxInputTokens(int maxInputTokens) { this.maxInputTokens = maxInputTokens; @@ -171,6 +223,11 @@ public List getModels() { return models; } + /** + * @param models + * @desciptions Restricts the models that can be used by the gateway. + * @default null (no restriction) + */ @MCAttribute public void setModels(List models) { this.models = models; @@ -180,6 +237,11 @@ public LLMProvider getProvider() { return provider; } + /** + * @param provider The LLM provider to use. + * @description The LLM provider to use. Currently, OpenAI, Anthropic and Gemini are supported. + * The provider determines the API used to talk to the LLM. The provider can be different as long as the API is supported. + */ @MCChildElement(allowForeign = true) public void setProvider(LLMProvider provider) { this.provider = provider; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java index 4e54f26b64..af2a351dc6 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java @@ -20,15 +20,19 @@ public ResponsesApiEvent(JsonNode json) { if ("response.output_item.done".equals(type)) { - JsonNode item = json.path("item"); + var item = json.path("item"); if (item.isObject()) { - ObjectNode on = (ObjectNode) item; + var on = (ObjectNode) item; if ("function_call".equals(on.path("type").asText())) { - log.info("Function call: {} with {} params", - on.path("name").asText(), - on.path("arguments").size()); + if (log.isDebugEnabled()) { + log.debug("Function call: {} with params {}", + on.path("name").asText(), + on.path("arguments").asText()); + } else { + log.info("Function call: {}", on.path("name")); + } } } } @@ -38,4 +42,11 @@ public ResponsesApiEvent(JsonNode json) { public String getType() { return type; } + + @Override + public String toString() { + return "ResponsesApiEvent{" + + "type='" + type + '\'' + + '}'; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java index b502f9e79f..95ecc7a77e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java @@ -7,10 +7,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collections; import java.util.List; import static com.predic8.membrane.core.http.Header.AUTHORIZATION; -import static java.util.Collections.emptyList; public abstract class AbstractLLMRequest extends AbstractLLMMessage implements LLMRequest { @@ -24,7 +24,7 @@ public AbstractLLMRequest(Exchange exchange) { super(exchange); if (exchange.getRequest().isJSON()) { - json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("No JSON object request.")); + json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("Cannot parse input as JSON message.")); } else { log.info("Request is not JSON:"); throw new RuntimeException("Request is not JSON."); @@ -32,26 +32,10 @@ public AbstractLLMRequest(Exchange exchange) { } public List getTools() { - var tools = getToolsNode(); - if (tools == null) - return emptyList(); - return tools.valueStream().map(n -> { - String type; - if (n.has("type")) { - type = n.get("type").asText(); - if (!"function".equals(type)) - return null; - } - - // Chat completion - if (n.has("function")) { - return n.get("function").get("name").asText(); - } - return n.get("name").asText(); - }).toList(); + return Collections.emptyList(); } - private ArrayNode getToolsNode() { + protected ArrayNode getToolsNode() { if (json == null) return null; if (json.path("tools").isArray()) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java index de54298fd1..f6e0c12c21 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java @@ -59,20 +59,13 @@ protected void processChunk(Chunk chunk, SSEParser parser) { events.forEach(this::process); terminal.ifPresent(event -> { - // Terminal of old chat completion API - if ("[DONE]".equals(event.data())) { - postProcessor.accept(AbstractLLMResponse.this); - return; - } - json = JsonUtil.getJsonObject(event.data()) - .orElse(JsonNodeFactory.instance.objectNode() - .put("error", "No JSON object response from model.")); - - // All is read, call postProcessor + processTerminalEvent(event); postProcessor.accept(AbstractLLMResponse.this); }); } + protected void processTerminalEvent(SSEParser.SSEEvent terminal) {}; + @Override public boolean isError() { return json.get("error") != null && !json.get("error").isNull(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java index 9b0c577ca9..83ad121a89 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java @@ -12,6 +12,12 @@ public interface LLMRequest { void setApiKey(String apiKey); + /** + * The max number of tokens that the model is allowed to generate as specified by the client. + * @return The max number of tokens that the model is allowed to generate. + */ + long getRequestedMaxOutputTokens(); + void setMaxOutputTokens(int maxOutputTokens); long estimateInputTokens(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java index 236f74a3ae..2a1151e855 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java @@ -76,6 +76,11 @@ private void disableThinking() { thinking.put("type", "disabled"); } + @Override + public long getRequestedMaxOutputTokens() { + return json.path("max_tokens").asLong(0); + } + @Override public String getApiKey() { return exchange.getRequest().getHeader().getFirstValue(X_API_KEY); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java index 0e44dc8c4b..dbdbbe3de4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java @@ -61,4 +61,5 @@ public void process(SSEEvent event) { public Usage getUsage() { return usage; } + } \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java index 471440f214..6be0db1566 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java @@ -8,6 +8,10 @@ import java.util.function.Consumer; +/** + * @description Anthroic Claude provider configuration + * Use to configure a LLM gateway to use the anthropic API + */ @MCElement( name="claude") public class ClaudeProvider implements LLMProvider { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java index 19ce54f6b2..e8a392bdc4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java @@ -50,6 +50,13 @@ public void setApiKey(String apiKey) { exchange.getRequest().getHeader().add(X_GOOG_API_KEY, apiKey); } + @Override + public long getRequestedMaxOutputTokens() { + return json.path("generationConfig") + .path("maxOutputTokens") + .asLong(0); + } + public long estimateInputTokens() { if (json == null || json.isNull()) { return 0; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java index 9a7d27e31d..729c97df98 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java @@ -8,6 +8,10 @@ import java.util.function.Consumer; +/** + * @description Google AI provider configuration + * Use to configure a LLM gateway to use the Google LLM API + */ @MCElement( name="google",id = "google-ai-provider") public class GoogleProvider implements LLMProvider { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/AbstractOpenAiLLMRequest.java similarity index 67% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/AbstractOpenAiLLMRequest.java index 784de9a983..7b8a76e4d1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/AbstractOpenAiLLMRequest.java @@ -4,22 +4,10 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMRequest; -public class OpenAiLLMRequest extends AbstractLLMRequest { +public abstract class AbstractOpenAiLLMRequest extends AbstractLLMRequest { - public OpenAiLLMRequest(Exchange exchange) { + public AbstractOpenAiLLMRequest(Exchange exchange) { super(exchange); - - if (json == null) { - return; - } - - // Make sure that when streaming is enabled, the usage is included in the response. - if (json.path("stream").asBoolean(false)) { - if (isChatCompletionsRequest(exchange)) { - var streamOptions = json.withObject("/stream_options"); - streamOptions.put("include_usage", true); - } - } } @Override @@ -40,18 +28,6 @@ public long estimateInputTokens() { return Math.max(1, Math.round(chars / 4.0 * 1.15)); } - @Override - public void setMaxOutputTokens(int maxOutputTokens) { - // OpenAI deprecated max_tokens for newer models (o1, o3, gpt-5.x) in - // favor of max_completion_tokens. Older models still accept max_tokens. - if (api == API.NORMAL) { - json.put("max_output_tokens", maxOutputTokens); - } - if (api == API.COMPLETIONS) { - json.put("max_completion_tokens", maxOutputTokens); - } - } - private long estimateChatCompletitions() { long chars = 0; // Chat Completions API @@ -108,8 +84,4 @@ private long countJsonSize(JsonNode node) { } return node.toString().length(); } - - private boolean isChatCompletionsRequest(Exchange exchange) { - return exchange.getRequest().getUri().contains("/chat/completions"); - } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java index 43c8ce9223..7e16bdbd54 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java @@ -8,18 +8,29 @@ import java.util.function.Consumer; +/** + * @description OpenAI provider configuration + * Use to configure a LLM gateway to use the OpenAI API + */ @MCElement( name="openai") public class OpenAIProvider implements LLMProvider { + boolean isResponsesAPI; + @Override public LLMRequest getLLMRequest(Exchange exchange) { - return new OpenAiLLMRequest(exchange); + isResponsesAPI = isResponsesApi(exchange); + if (isResponsesAPI) { + return new OpenAiLLMResponsesRequest(exchange); + } + + return new OpenAiLLMChatCompletionsRequest(exchange); } @Override public LLMResponse getLLMResponse(Exchange exchange, Consumer postProcessor) { - if (isResponsesApi(exchange)) { - return new OpenAiLLMResponsesAPIResponse(exchange,postProcessor); + if (isResponsesAPI) { + return new OpenAiLLMResponsesResponse(exchange,postProcessor); } return new OpenAiChatCompletionsLLMResponse(exchange, postProcessor); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java index cb226bdf28..bb0a206a8c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java @@ -41,6 +41,11 @@ public Set getTerminalEvents() { return Set.of("[DONE]"); } + @Override + protected void processTerminalEvent(SSEParser.SSEEvent terminal) { + postProcessor.accept(OpenAiChatCompletionsLLMResponse.this); + } + @Override public void process(SSEParser.SSEEvent e) { log.debug("Data: {}", e.data()); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java new file mode 100644 index 0000000000..69515d796f --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java @@ -0,0 +1,51 @@ +package com.predic8.membrane.core.interceptor.ai.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; + +import java.util.List; + +import static java.util.Collections.emptyList; + +public class OpenAiLLMChatCompletionsRequest extends AbstractOpenAiLLMRequest { + + public OpenAiLLMChatCompletionsRequest(Exchange exchange) { + super(exchange); + + if (json == null) { + return; + } + + // Make sure that when streaming is enabled, the usage is included in the response. + if (json.path("stream").asBoolean(false)) { + var streamOptions = json.withObject("/stream_options"); + streamOptions.put("include_usage", true); + } + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + json.put("max_completion_tokens", maxOutputTokens); + } + + public List getTools() { + var tools = getToolsNode(); + if (tools == null) + return emptyList(); + return tools.valueStream().map(n -> { + String type; + if (n.has("type")) { + type = n.get("type").asText(); + if (!"function".equals(type)) + return null; + } + + return n.get("function").get("name").asText(); + }).toList(); + } + + @Override + public long getRequestedMaxOutputTokens() { + return json.path("max_tokens").asLong(0); + } + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java new file mode 100644 index 0000000000..afb2a05c62 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java @@ -0,0 +1,41 @@ +package com.predic8.membrane.core.interceptor.ai.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; + +import java.util.List; + +import static java.util.Collections.emptyList; + +public class OpenAiLLMResponsesRequest extends AbstractOpenAiLLMRequest { + + public OpenAiLLMResponsesRequest(Exchange exchange) { + super(exchange); + } + + public List getTools() { + var tools = getToolsNode(); + if (tools == null) + return emptyList(); + return tools.valueStream().map(n -> { + String type; + if (n.has("type")) { + type = n.get("type").asText(); + if (!"function".equals(type)) + return null; + } + return n.get("name").asText(); + }).toList(); + } + + @Override + public long getRequestedMaxOutputTokens() { + if (json.has("max_output_tokens")) + return json.get("max_output_tokens").asLong(); + return 0; + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + json.put("max_output_tokens", maxOutputTokens); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesResponse.java similarity index 72% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesResponse.java index 4cff5b313f..67e836ebd6 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesAPIResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesResponse.java @@ -1,22 +1,24 @@ package com.predic8.membrane.core.interceptor.ai.provider.openai; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.ai.AbstractLLMEvent; import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; import com.predic8.membrane.core.interceptor.ai.store.Usage; import com.predic8.membrane.core.util.http.SSEParser; +import com.predic8.membrane.core.util.json.JsonUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Set; import java.util.function.Consumer; -public class OpenAiLLMResponsesAPIResponse extends AbstractLLMResponse { +public class OpenAiLLMResponsesResponse extends AbstractLLMResponse { - private static final Logger log = LoggerFactory.getLogger(OpenAiLLMResponsesAPIResponse.class); + private static final Logger log = LoggerFactory.getLogger(OpenAiLLMResponsesResponse.class); - public OpenAiLLMResponsesAPIResponse(Exchange exchange, Consumer postProcessor) { + public OpenAiLLMResponsesResponse(Exchange exchange, Consumer postProcessor) { super(exchange, postProcessor); } @@ -42,6 +44,13 @@ public Set getTerminalEvents() { return Set.of("response.completed", "response.incomplete"); } + @Override + protected void processTerminalEvent(SSEParser.SSEEvent terminal) { + json = JsonUtil.getJsonObject(terminal.data()) + .orElse(JsonNodeFactory.instance.objectNode() + .put("error", "No JSON object response from model.")); + } + @Override public void process(SSEParser.SSEEvent e) { log.debug("Event: {}", e.name()); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index eb82874372..a04135d0b6 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -11,6 +11,9 @@ import static java.time.Instant.now; +/** + * @description Limits the number of tokens that can be used for a specific API. + */ @MCElement(name = "limit", component = false, id = "ai-api-limit") public class AiApiLimit { @@ -57,6 +60,11 @@ public int getMaxTokens() { return maxTokens; } + /** + * @description Maximum number of tokens that can be used within a period. + * @default 0 (no limit) + * @param maxTokens Maximum number of tokens + */ @MCAttribute public void setMaxTokens(int maxTokens) { this.maxTokens = maxTokens; @@ -66,6 +74,11 @@ public int getPeriod() { return period; } + /** + * @description Period after which the token limit resets. + * @default 0 (no limit) + * @param period in seconds + */ @MCAttribute public void setPeriod(int period) { synchronized (lock) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java index a867bfddcc..91bdf153fb 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -13,6 +13,10 @@ public String getName() { return name; } + /** + * @description Name of the API user, group or cost center. + * @param name of the user + */ @MCAttribute() public void setName(String name) { this.name = name; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java index acff4625c1..d4ad25e009 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java @@ -7,12 +7,16 @@ import java.sql.SQLException; import java.util.Optional; +/** + * @description Stores AI API usage in a database (experimental). + */ @MCElement(name = "jdbcAiApiUsageStore") public class JDBCAiApiUsageStore extends AbstractJdbcSupport implements AiApiStore { + // @TODO GENERATED ALWAYS AS IDENTITY is PostgreSQL specific private static final String CREATE_TABLE_SQL = """ CREATE TABLE IF NOT EXISTS ai_api_usage ( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, // @TODO GENERATED ALWAYS AS IDENTITY is PostgreSQL specific + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, username VARCHAR(255) NOT NULL, input_tokens INT NOT NULL, output_tokens INT NOT NULL, diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java index 63cdaa6d40..bf2a74a464 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java @@ -1,5 +1,8 @@ package com.predic8.membrane.core.interceptor.ai.store; +/** + * @description Store that does not limit the number of AI API calls (experimental). + */ public class NoAiApiLimit extends AiApiLimit{ @Override diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index ff99faa98a..5b1697d849 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -9,6 +9,10 @@ import java.util.List; import java.util.Optional; +/** + * @description Simple store for the LLM Gateway that stores limits in memory. Users and keys can + * be configured in the configuration file. + */ @MCElement(name="simpleStore",component = false, id="simple-ai-api-store") public class SimpleAiApiStore implements AiApiStore { @@ -17,9 +21,12 @@ public class SimpleAiApiStore implements AiApiStore { private List users = new ArrayList<>(); private AiApiLimit limit = new NoAiApiLimit(); + private boolean logUsage = true; + @Override public void store(AiApiUser user, Usage usage) { - log.debug("Usage: {}", usage); + if (logUsage) + log.info("user: {} {}",user.getName(),usage.toString()); limit.addTokens(usage.totalTokens()); } @@ -36,6 +43,10 @@ public long checkLimit(AiApiUser user, long inputTokens, long outputTokens) { return limit.checkLimit(inputTokens + outputTokens); } + /** + * List of users that can be used for authentication. + * @param users User list + */ @MCChildElement(allowForeign = true,order = 10) public void setUsers(List users) { this.users = users; @@ -49,9 +60,22 @@ public AiApiLimit getLimit() { return limit; } + /** + * @description The limit of tokens that can be used for each user. + * @default 0 (no limit) + * @param limit + */ @MCChildElement(allowForeign = true) public void setLimit(AiApiLimit limit) { this.limit = limit; } + + public boolean isLogUsage() { + return logUsage; + } + + public void setLogUsage(boolean logUsage) { + this.logUsage = logUsage; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java index b6c99902e5..515efe0803 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java @@ -117,9 +117,9 @@ private static Optional getJsonObjectFromSteam(InputStream obj) { if (node instanceof ObjectNode on) { return Optional.of(on); } - log.debug("Expected JSON Object but got: {}",node.getNodeType()); + log.info("Expected JSON Object but got: {}",node.getNodeType()); } catch (Exception e) { - log.debug("Error reading JSON: {}", e.getMessage()); + log.info("Error reading JSON: {}", e.getMessage()); } return empty(); } diff --git a/distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml new file mode 100644 index 0000000000..edf4e5b368 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml @@ -0,0 +1,28 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json +# +# Tutorial: Basic LLM Gateway +# +# Replace <> with your OpenAI API key. +# +# 1. Hello World +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @simple.json http://localhost:2000/v1/responses +# +# 2. Exceed the input token limit +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-input.json http://localhost:2000/v1/responses +# => Returns an error because the request exceeds maxInputTokens. +# +# 3. Exceed the output token limit +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-output.json http://localhost:2000/v1/responses +# => Check the max_output_tokens field in the response. + +api: + port: 2000 + flow: + - llmGateway: + openai: {} + maxInputTokens: 100 + maxOutputTokens: 200 + - request: + - log: {} + target: + url: https://api.openai.com diff --git a/distribution/tutorials/ai/llm-gateway/max-input.json b/distribution/tutorials/ai/llm-gateway/max-input.json new file mode 100644 index 0000000000..e4b0e90985 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/max-input.json @@ -0,0 +1,4 @@ +{ + "model": "gpt-5-nano", + "input": "Who are you, where do you get your information from, how do you answer questions, why were you created, what kinds of problems can you solve, where do you go when you search for information, how do you decide what is important, what do you know about programming, science, history, languages, and technology, how do you explain difficult concepts to people, why do people use AI assistants, what happens when you do not know an answer, and why should someone trust the answers you provide?" +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/max-output.json b/distribution/tutorials/ai/llm-gateway/max-output.json new file mode 100644 index 0000000000..65a63165f2 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/max-output.json @@ -0,0 +1,5 @@ +{ + "model": "gpt-5-nano", + "input": "What is your name?", + "max_output_tokens": 500 +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/membrane.cmd b/distribution/tutorials/ai/llm-gateway/membrane.cmd new file mode 100644 index 0000000000..8d2d64e9cf --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/membrane.cmd @@ -0,0 +1,24 @@ +@echo off +setlocal EnableExtensions + +set "SCRIPT_DIR=%~dp0" +if "%SCRIPT_DIR:~-1%"=="\" set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%" + +set "dir=%SCRIPT_DIR%" + +:search_up +if exist "%dir%\LICENSE.txt" if exist "%dir%\scripts\run-membrane.cmd" goto found +for %%A in ("%dir%\..") do set "next=%%~fA" +if /I "%next%"=="%dir%" goto notfound +set "dir=%next%" +goto search_up + +:found +set "MEMBRANE_HOME=%dir%" +set "MEMBRANE_CALLER_DIR=%SCRIPT_DIR%" +call "%MEMBRANE_HOME%\scripts\run-membrane.cmd" %* +exit /b %ERRORLEVEL% + +:notfound +>&2 echo Could not locate Membrane root. Ensure directory structure is correct. +exit /b 1 diff --git a/distribution/tutorials/ai/llm-gateway/membrane.sh b/distribution/tutorials/ai/llm-gateway/membrane.sh new file mode 100755 index 0000000000..195dae51ec --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/membrane.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# Default: ./proxies.xml (next to this script); fallback -> $MEMBRANE_HOME/conf/proxies.xml +# JAVA_OPTS: relative -D paths are auto-resolved against $MEMBRANE_HOME (absolute/URI unchanged). +# Examples: +# export JAVA_OPTS='-Dlog4j.configurationFile=examples/logging/access/log4j2_access.xml' +# export JAVA_OPTS='-Dlog4j.configurationFile=/abs/path/log4j2.xml' + +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd -P) + +dir="$SCRIPT_DIR" +while [ "$dir" != "/" ]; do + if [ -f "$dir/LICENSE.txt" ] && [ -f "$dir/scripts/run-membrane.sh" ]; then + export MEMBRANE_HOME="$dir" + export MEMBRANE_CALLER_DIR="$SCRIPT_DIR" + exec sh "$dir/scripts/run-membrane.sh" "$@" + fi + dir=$(dirname "$dir") +done + +echo "Could not locate Membrane root. Ensure directory structure is correct." >&2 +exit 1 \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/simple.json b/distribution/tutorials/ai/llm-gateway/simple.json new file mode 100644 index 0000000000..ab3c4b7bde --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/simple.json @@ -0,0 +1,4 @@ +{ + "model": "gpt-5-nano", + "input": "Who are you?" +} \ No newline at end of file From 27c644d4cb75a8abd8ba60252688e4c1eb02975d Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 10:23:24 +0200 Subject: [PATCH 20/43] refactor: remove redundant semicolon in `processTerminalEvent` method --- .../core/interceptor/ai/provider/AbstractLLMResponse.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java index f6e0c12c21..4967d34af4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java @@ -64,7 +64,7 @@ protected void processChunk(Chunk chunk, SSEParser parser) { }); } - protected void processTerminalEvent(SSEParser.SSEEvent terminal) {}; + protected void processTerminalEvent(SSEParser.SSEEvent terminal) {} @Override public boolean isError() { From 383856ae7eaf91a2d515a9a71c4d6911dc33ac31 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 13:21:21 +0200 Subject: [PATCH 21/43] refactor: simplify logic for API response handling and token usage tracking - Removed redundant `isResponsesAPI` variable in `OpenAIProvider`. - Optimized tool extraction in `OpenAiLLMResponsesRequest` and `OpenAiLLMChatCompletionsRequest`. - Updated `AiApiLimit` to support unlimited tokens with `MAX_VALUE`. - Replaced `token` with `apiKey` in `AiApiUser` along with added `tokens` field. - Improved JSON parsing logic in `JsonUtil` with better exception handling and logging. - Adjusted output token parameter naming in multiple request classes for consistency. --- .../interceptor/ai/LLMGatewayInterceptor.java | 4 +-- .../ai/provider/google/GoogleLLMRequest.java | 2 +- .../ai/provider/openai/OpenAIProvider.java | 7 ++-- .../OpenAiLLMChatCompletionsRequest.java | 17 ++++------ .../openai/OpenAiLLMResponsesRequest.java | 16 ++++------ .../core/interceptor/ai/store/AiApiLimit.java | 12 ++++--- .../core/interceptor/ai/store/AiApiUser.java | 32 ++++++++++++++++--- .../membrane/core/util/json/JsonUtil.java | 23 ++++++++++--- 8 files changed, 70 insertions(+), 43 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 56ef011243..c020a62b58 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -108,7 +108,7 @@ public Outcome handleRequest(Exchange exc) { aiReq.setApiKey(apiKey); } - log.debug("max-tokens from client: {}", aiReq.getModel()); + log.debug("Requested model: {}", aiReq.getModel()); var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); @@ -224,8 +224,8 @@ public List getModels() { } /** + * @desciption Restricts the models that can be used by the gateway. * @param models - * @desciptions Restricts the models that can be used by the gateway. * @default null (no restriction) */ @MCAttribute diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java index e8a392bdc4..aff431819d 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java @@ -113,7 +113,7 @@ private long countText(JsonNode node) { @Override public void setMaxOutputTokens(int maxOutputTokens) { - getGenerationConfig().put("max_output_tokens", maxOutputTokens); + getGenerationConfig().put("maxOutputTokens", maxOutputTokens); } private ObjectNode getGenerationConfig() { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java index 7e16bdbd54..01af8104a3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java @@ -15,12 +15,9 @@ @MCElement( name="openai") public class OpenAIProvider implements LLMProvider { - boolean isResponsesAPI; - @Override public LLMRequest getLLMRequest(Exchange exchange) { - isResponsesAPI = isResponsesApi(exchange); - if (isResponsesAPI) { + if (isResponsesApi(exchange)) { return new OpenAiLLMResponsesRequest(exchange); } @@ -29,7 +26,7 @@ public LLMRequest getLLMRequest(Exchange exchange) { @Override public LLMResponse getLLMResponse(Exchange exchange, Consumer postProcessor) { - if (isResponsesAPI) { + if (isResponsesApi(exchange)) { return new OpenAiLLMResponsesResponse(exchange,postProcessor); } return new OpenAiChatCompletionsLLMResponse(exchange, postProcessor); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java index 69515d796f..5c57339682 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java @@ -31,21 +31,16 @@ public List getTools() { var tools = getToolsNode(); if (tools == null) return emptyList(); - return tools.valueStream().map(n -> { - String type; - if (n.has("type")) { - type = n.get("type").asText(); - if (!"function".equals(type)) - return null; - } - - return n.get("function").get("name").asText(); - }).toList(); + return tools.valueStream() + .filter(n -> "function".equals(n.path("type").asText(""))) + .map(n -> n.path("function").path("name").asText("")) + .filter(name -> !name.isEmpty()) + .toList(); } @Override public long getRequestedMaxOutputTokens() { - return json.path("max_tokens").asLong(0); + return json.path("max_completion_tokens").asLong(0); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java index afb2a05c62..e6547d7e25 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java @@ -16,15 +16,11 @@ public List getTools() { var tools = getToolsNode(); if (tools == null) return emptyList(); - return tools.valueStream().map(n -> { - String type; - if (n.has("type")) { - type = n.get("type").asText(); - if (!"function".equals(type)) - return null; - } - return n.get("name").asText(); - }).toList(); + return tools.valueStream() + .filter(n -> "function".equals(n.path("type").asText(""))) + .map(n -> n.path("name").asText("")) + .filter(name -> !name.isEmpty()) + .toList(); } @Override @@ -36,6 +32,6 @@ public long getRequestedMaxOutputTokens() { @Override public void setMaxOutputTokens(int maxOutputTokens) { - json.put("max_output_tokens", maxOutputTokens); + json.put("max_output_tokens", maxOutputTokens); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java index a04135d0b6..fa69191180 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java @@ -9,6 +9,7 @@ import java.time.Instant; import java.util.concurrent.atomic.AtomicLong; +import static java.lang.Long.MAX_VALUE; import static java.time.Instant.now; /** @@ -19,7 +20,7 @@ public class AiApiLimit { private static final Logger log = LoggerFactory.getLogger(AiApiLimit.class); - private int maxTokens; + private long maxTokens = MAX_VALUE; private int period; private final Object lock = new Object(); @@ -37,6 +38,9 @@ public class AiApiLimit { * @return Estimated remaining tokens after this call. */ public long checkLimit(long tokensForNextRequest) { + if (maxTokens == MAX_VALUE) { + return MAX_VALUE; + } synchronized (lock) { Instant now = now(); if (nextReset == null || now.isAfter(nextReset)) { @@ -56,17 +60,17 @@ public void addTokens(long tokens) { } } - public int getMaxTokens() { + public long getMaxTokens() { return maxTokens; } /** * @description Maximum number of tokens that can be used within a period. - * @default 0 (no limit) + * @default MAX_VALUE (no limit) * @param maxTokens Maximum number of tokens */ @MCAttribute - public void setMaxTokens(int maxTokens) { + public void setMaxTokens(long maxTokens) { this.maxTokens = maxTokens; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java index 91bdf153fb..38e39648f1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -7,7 +7,9 @@ public class AiApiUser { private String name; - private String token; + private String apiKey; + + private long tokens; public String getName() { return name; @@ -22,13 +24,33 @@ public void setName(String name) { this.name = name; } - public String getToken() { - return token; + public String getApiKey() { + return apiKey; } + /** + * @description API key to authenticate the user at the llm gateway + * @default (not set) + * @param apikey to authenticate the user + */ @MCAttribute() - public void setToken(String token) { - this.token = token; + public void setApiKey(String apikey) { + this.apiKey = apikey; + } + + + public long getTokens() { + return tokens; + } + + /** + * @description Number of tokens that the user has available within the current period. + * @default 0 (no limit) + * @param tokens available tokens + */ + @MCAttribute + public void setTokens(long tokens) { + this.tokens = tokens; } @Override diff --git a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java index 515efe0803..644c5ca414 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/json/JsonUtil.java @@ -93,9 +93,16 @@ public static JsonNode scalarAsJson(String value) { return FACTORY.textNode(value); } - public static Optional getJsonObject(String s) { + /** + * Get JSON object from message body. + * The caller must deal with the possibility that the body is not a JSON object or + * there are parsing errors. + * @param jsonString String with a JSON body + * @return JSON object or empty if the body is not a JSON object or there are parsing errors + */ + public static Optional getJsonObject(String jsonString) { try { - var node = om.readTree(s); + var node = om.readTree(jsonString); if (node instanceof ObjectNode on) { return Optional.of(on); } @@ -106,7 +113,13 @@ public static Optional getJsonObject(String s) { return empty(); } - + /** + * Get JSON object from message body. + * The caller must deal with the possibility that the body is not a JSON object or + * there are parsing errors. + * @param msg With a JSON body + * @return JSON object or empty if the body is not a JSON object or there are parsing errors + */ public static Optional getJsonObject(Message msg) { return getJsonObjectFromSteam(msg.getBodyAsStreamDecoded()); } @@ -117,9 +130,9 @@ private static Optional getJsonObjectFromSteam(InputStream obj) { if (node instanceof ObjectNode on) { return Optional.of(on); } - log.info("Expected JSON Object but got: {}",node.getNodeType()); + log.debug("Expected JSON Object but got: {}",node.getNodeType()); } catch (Exception e) { - log.info("Error reading JSON: {}", e.getMessage()); + log.debug("Error reading JSON: {}", e.getMessage()); } return empty(); } From 169cb2cbf7c447349ffbaf6b77abb220dce6b981 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 15:22:05 +0200 Subject: [PATCH 22/43] refactor: remove outdated AI API limit management and centralize error handling - Deleted `AiApiLimit` and `NoAiApiLimit` classes, consolidating token management into `SimpleAiApiStore`. - Introduced `LLMErrorCreator` and its implementations (`OpenAiErrorCreator`, etc.) for reusable error generation. - Refactored `LLMGatewayInterceptor` to utilize provider-specific error creators, simplifying token and model validation. - Enhanced `SimpleAiApiStore` with token reset functionality and user-specific token tracking. - Updated tutorials and examples to align with this refactored approach. --- .../core/interceptor/ai/LLMApiUtil.java | 84 ----------------- .../interceptor/ai/LLMGatewayInterceptor.java | 18 ++-- .../ai/provider/AbstractLLMErrorCreator.java | 33 +++++++ .../ai/provider/LLMErrorCreator.java | 16 ++++ .../interceptor/ai/provider/LLMProvider.java | 1 + .../interceptor/ai/provider/LLMRequest.java | 2 +- .../ai/provider/claude/ClaudeProvider.java | 8 +- .../ai/provider/google/GoogleProvider.java | 8 +- .../ai/provider/openai/OpenAIProvider.java | 6 ++ .../provider/openai/OpenAiErrorCreator.java | 48 ++++++++++ .../openai/OpenAiLLMResponsesRequest.java | 2 +- .../core/interceptor/ai/store/AiApiLimit.java | 93 ------------------- .../core/interceptor/ai/store/AiApiStore.java | 2 + .../core/interceptor/ai/store/AiApiUser.java | 20 ++++ .../ai/store/JDBCAiApiUsageStore.java | 5 + .../interceptor/ai/store/NoAiApiLimit.java | 12 --- .../ai/store/SimpleAiApiStore.java | 51 +++++++--- .../ai/llm-gateway/20-Sharing-API-Keys.yaml | 55 +++++++++++ 18 files changed, 251 insertions(+), 213 deletions(-) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMErrorCreator.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java create mode 100644 distribution/tutorials/ai/llm-gateway/20-Sharing-API-Keys.yaml diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java deleted file mode 100644 index 2f86121750..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMApiUtil.java +++ /dev/null @@ -1,84 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.predic8.membrane.core.http.Response; - -import java.util.Collection; - -import static com.predic8.membrane.core.http.Header.WWW_AUTHENTICATE; -import static com.predic8.membrane.core.http.Response.badRequest; -import static com.predic8.membrane.core.http.Response.unauthorized; - -public class LLMApiUtil { - - private static final ObjectMapper om = new ObjectMapper(); - - public static Response modelNotAllowed(String model, Collection allowedModels) { - return badRequest().json(createJson(new ErrorEnvelope( - new ErrorBody( - "Model '%s' is not allowed. Allowed models: %s." - .formatted(model, String.join(", ", allowedModels)), - "invalid_request_error", - null, - "model_not_allowed" - ) - ))).build(); - } - - public static Response authenticationFailed() { - return unauthorized().header(WWW_AUTHENTICATE, "Bearer").json(createJson(new ErrorEnvelope( - new ErrorBody( - "Invalid authentication credentials", - "invalid_request_error", - null, - "invalid_authentication" - ) - ))).build(); - } - - public static Response contextLengthExceeded(long maxTokens, long estimatedTokens) { - return badRequest().json(createJson(new ErrorEnvelope(new ErrorBody( - """ - This model's maximum context length is %d tokens. - Your request contains approximately %d tokens. - """.formatted(maxTokens, estimatedTokens).trim(), - "invalid_request_error", - "input", - "context_length_exceeded" - )))).build(); - } - - public static Response tokenLimitExceeded() { - return badRequest() - .json(createJson(new ErrorEnvelope( - new ErrorBody( - "Token rate limit exceeded.", - "rate_limit_error", - null, - "token_limit_exceeded" - ) - ))) - .build(); - } - - public static String createJson(Object o) { - try { - return om.writeValueAsString(o); - } catch (Exception e) { - return """ - { "error": "Could not create JSON" } - """; - } - } - - record ErrorEnvelope(ErrorBody error) { - } - - record ErrorBody( - String message, - String type, - String param, - String code - ) { - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index c020a62b58..a841e2f922 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -6,6 +6,7 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.AbstractInterceptor; import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; @@ -18,7 +19,6 @@ import static com.predic8.membrane.core.exceptions.ProblemDetails.user; import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; -import static com.predic8.membrane.core.interceptor.ai.LLMApiUtil.*; import static com.predic8.membrane.core.util.json.JsonUtil.setJsonBody; /* @@ -42,6 +42,7 @@ public class LLMGatewayInterceptor extends AbstractInterceptor { public static final String MEMBRANE_AI_USER = "membrane.ai.user"; private LLMProvider provider; + private LLMErrorCreator errorCreator; private String apiKey; private int maxOutputTokens; @@ -52,6 +53,7 @@ public class LLMGatewayInterceptor extends AbstractInterceptor { @Override public void init() { + errorCreator = provider.getErrorCreator(); if (store != null) store.init(router); } @@ -79,7 +81,7 @@ public Outcome handleRequest(Exchange exc) { if (store != null) { var opt = store.getUser(aiReq.getApiKey()); if (opt.isEmpty()) { - exc.setResponse(authenticationFailed()); + exc.setResponse(errorCreator.authenticationFailed()); return RETURN; } user = opt.get(); @@ -92,10 +94,10 @@ public Outcome handleRequest(Exchange exc) { log.debug("Estimated input tokens: {}", inputTokens); if (store != null) { var remaining = store.checkLimit(user, inputTokens, maxOutputTokens); - log.debug("Remaining tokens: {}", remaining); + log.debug("User {} has {} remaining tokens left", user, remaining); if (remaining <= 0) { - log.info("Token limit exceeded: {}/{}", inputTokens, maxOutputTokens); - exc.setResponse(tokenLimitExceeded()); + log.info("Token limit exceeded. Remaining: {} input: {} maxOutput: {}",remaining, inputTokens, maxOutputTokens); + exc.setResponse(errorCreator.tokenLimitExceeded(inputTokens+maxOutputTokens, remaining, store.getRemainingResetTime())); return RETURN; } } @@ -112,7 +114,7 @@ public Outcome handleRequest(Exchange exc) { var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); - if (maxOutputTokens != 0 && requestedMaxOutputTokens > maxOutputTokens) { + if (maxOutputTokens != 0 && (requestedMaxOutputTokens == -1 || requestedMaxOutputTokens > maxOutputTokens)) { log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.",requestedMaxOutputTokens, maxOutputTokens); aiReq.setMaxOutputTokens(maxOutputTokens); } @@ -120,7 +122,7 @@ public Outcome handleRequest(Exchange exc) { if (maxInputTokens != 0) { if (inputTokens > maxInputTokens) { log.info("Input tokens {} exceed the limit of {}.",inputTokens, maxInputTokens); - exc.setResponse(contextLengthExceeded(maxInputTokens, inputTokens)); + exc.setResponse(errorCreator.contextLengthExceeded(maxInputTokens, inputTokens)); return RETURN; } } @@ -128,7 +130,7 @@ public Outcome handleRequest(Exchange exc) { if (models != null) { var model = aiReq.getModel(); if (!models.contains(model)) { - exc.setResponse(modelNotAllowed(model, models)); + exc.setResponse(errorCreator.modelNotAllowed(model, models)); return RETURN; } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMErrorCreator.java new file mode 100644 index 0000000000..6e6d739711 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMErrorCreator.java @@ -0,0 +1,33 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.fasterxml.jackson.databind.ObjectMapper; + +public abstract class AbstractLLMErrorCreator implements LLMErrorCreator { + + private static final ObjectMapper om = new ObjectMapper(); + + public static String createJson(Object o) { + try { + return om.writeValueAsString(o); + } catch (Exception e) { + return """ + { "error": "Could not create JSON" } + """; + } + } + + public String envelope(String message, String type, String param, String code) { + return createJson(new ErrorEnvelope(new ErrorBody(message,type,param,code))); + } + + private record ErrorEnvelope(ErrorBody error) { + } + + private record ErrorBody( + String message, + String type, + String param, + String code + ) { + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java new file mode 100644 index 0000000000..cd36b17738 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java @@ -0,0 +1,16 @@ +package com.predic8.membrane.core.interceptor.ai.provider; + +import com.predic8.membrane.core.http.Response; + +import java.util.Collection; + +public interface LLMErrorCreator { + + Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds); + + Response modelNotAllowed(String model, Collection allowedModels); + + Response authenticationFailed(); + + Response contextLengthExceeded(long maxTokens, long estimatedTokens); +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java index 102ab072d1..5b52994751 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java @@ -8,5 +8,6 @@ public interface LLMProvider { LLMRequest getLLMRequest(Exchange request); LLMResponse getLLMResponse(Exchange request, Consumer postProcessor); + LLMErrorCreator getErrorCreator(); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java index 83ad121a89..a6f377686b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java @@ -14,7 +14,7 @@ public interface LLMRequest { /** * The max number of tokens that the model is allowed to generate as specified by the client. - * @return The max number of tokens that the model is allowed to generate. + * @return The max number of tokens that the model is allowed to generate. -1 if no limit is set. */ long getRequestedMaxOutputTokens(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java index 6be0db1566..08ea66e425 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java @@ -2,6 +2,7 @@ import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; @@ -9,7 +10,7 @@ import java.util.function.Consumer; /** - * @description Anthroic Claude provider configuration + * @description (Experimental) Anthroic Claude provider configuration * Use to configure a LLM gateway to use the anthropic API */ @MCElement( name="claude") @@ -24,4 +25,9 @@ public LLMRequest getLLMRequest(Exchange exchange) { public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { return new ClaudeLLMResponse(request, postProcessor); } + + @Override + public LLMErrorCreator getErrorCreator() { + return null; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java index 729c97df98..8c3eab8b41 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java @@ -2,6 +2,7 @@ import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; @@ -9,7 +10,7 @@ import java.util.function.Consumer; /** - * @description Google AI provider configuration + * @description (Experimental)Google AI provider configuration * Use to configure a LLM gateway to use the Google LLM API */ @MCElement( name="google",id = "google-ai-provider") @@ -24,4 +25,9 @@ public LLMRequest getLLMRequest(Exchange exchange) { public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { return new GoogleLLMResponse(request, postProcessor); } + + @Override + public LLMErrorCreator getErrorCreator() { + return null; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java index 01af8104a3..8a1aa29436 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java @@ -2,6 +2,7 @@ import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; @@ -32,6 +33,11 @@ public LLMResponse getLLMResponse(Exchange exchange, Consumer postP return new OpenAiChatCompletionsLLMResponse(exchange, postProcessor); } + @Override + public LLMErrorCreator getErrorCreator() { + return new OpenAiErrorCreator(); + } + static boolean isResponsesApi(Exchange exchange) { return exchange.getRequest().getUri().startsWith("/v1/responses"); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java new file mode 100644 index 0000000000..b919a73ef9 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java @@ -0,0 +1,48 @@ +package com.predic8.membrane.core.interceptor.ai.provider.openai; + +import com.predic8.membrane.core.http.Response; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMErrorCreator; + +import java.util.Collection; + +import static com.predic8.membrane.core.http.Header.WWW_AUTHENTICATE; +import static com.predic8.membrane.core.http.Response.*; + +public class OpenAiErrorCreator extends AbstractLLMErrorCreator { + + public Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds) { + return statusCode(429).json(envelope( + "Token rate limit exceeded. Request requires %d tokens but only %d remain. Please wait %d seconds before retrying.".formatted(tokenRequired, tokenRemaining, tokenResetInSeconds), + "rate_limit_error", + null, + "token_limit_exceeded")).build(); + } + + public Response modelNotAllowed(String model, Collection allowedModels) { + return badRequest().json(envelope( + "Model '%s' is not allowed. Allowed models: %s." + .formatted(model, String.join(", ", allowedModels)), + "invalid_request_error", + null, + "model_not_allowed")).build(); + } + + public Response authenticationFailed() { + return unauthorized().header(WWW_AUTHENTICATE, "Bearer").json(envelope( + "Invalid authentication credentials", + "invalid_request_error", + null, + "invalid_authentication")).build(); + } + + public Response contextLengthExceeded(long maxTokens, long estimatedTokens) { + return badRequest().json(envelope( + """ + This model's maximum context length is %d tokens. + Your request contains approximately %d tokens. + """.formatted(maxTokens, estimatedTokens).trim(), + "invalid_request_error", + "input", + "context_length_exceeded")).build(); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java index e6547d7e25..2568755848 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java @@ -27,7 +27,7 @@ public List getTools() { public long getRequestedMaxOutputTokens() { if (json.has("max_output_tokens")) return json.get("max_output_tokens").asLong(); - return 0; + return -1; } @Override diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java deleted file mode 100644 index fa69191180..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiLimit.java +++ /dev/null @@ -1,93 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.store; - -import com.predic8.membrane.annot.MCAttribute; -import com.predic8.membrane.annot.MCElement; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.annotation.concurrent.GuardedBy; -import java.time.Instant; -import java.util.concurrent.atomic.AtomicLong; - -import static java.lang.Long.MAX_VALUE; -import static java.time.Instant.now; - -/** - * @description Limits the number of tokens that can be used for a specific API. - */ -@MCElement(name = "limit", component = false, id = "ai-api-limit") -public class AiApiLimit { - - private static final Logger log = LoggerFactory.getLogger(AiApiLimit.class); - - private long maxTokens = MAX_VALUE; - private int period; - - private final Object lock = new Object(); - - @GuardedBy("lock") - private Instant nextReset; - - private final AtomicLong tokens = new AtomicLong(0); - - /** - * Checks if the user has enough tokens to make the request. - * If there aren't enough tokens for the request, 0 or a negative number is returned. - * - * @param tokensForNextRequest Estimation of the number of tokens that will be used for the next request. - * @return Estimated remaining tokens after this call. - */ - public long checkLimit(long tokensForNextRequest) { - if (maxTokens == MAX_VALUE) { - return MAX_VALUE; - } - synchronized (lock) { - Instant now = now(); - if (nextReset == null || now.isAfter(nextReset)) { - tokens.set(0); - nextReset = now.plusSeconds(period); - log.debug("Resetting AI API usage limit."); - } - } - - return maxTokens - tokens.get() - tokensForNextRequest; - } - - public void addTokens(long tokens) { - synchronized (lock) { - log.debug("Adding {} tokens to AI API usage limit.", tokens); - this.tokens.addAndGet(tokens); - } - } - - public long getMaxTokens() { - return maxTokens; - } - - /** - * @description Maximum number of tokens that can be used within a period. - * @default MAX_VALUE (no limit) - * @param maxTokens Maximum number of tokens - */ - @MCAttribute - public void setMaxTokens(long maxTokens) { - this.maxTokens = maxTokens; - } - - public int getPeriod() { - return period; - } - - /** - * @description Period after which the token limit resets. - * @default 0 (no limit) - * @param period in seconds - */ - @MCAttribute - public void setPeriod(int period) { - synchronized (lock) { - this.period = period; - nextReset = now().plusSeconds(period); - } - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java index 4f3a4c900c..d78da8fcd3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java @@ -23,5 +23,7 @@ default void init(Router router) { * @return */ long checkLimit(AiApiUser user, long inputTokens, long outputTokens); + + long getRemainingResetTime(); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java index 38e39648f1..257ef86a28 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -3,6 +3,8 @@ import com.predic8.membrane.annot.MCAttribute; import com.predic8.membrane.annot.MCElement; +import java.util.concurrent.atomic.AtomicLong; + @MCElement(name = "users", component = false, id="ai-api-users") public class AiApiUser { @@ -11,6 +13,24 @@ public class AiApiUser { private long tokens; + private AtomicLong tokensUsedInPeriod = new AtomicLong(); + + public void addTokensUsedInPeriod(Usage usage) { + tokensUsedInPeriod.addAndGet(usage.totalTokens()); + } + + public void resetTokensUsedInPeriod() { + tokensUsedInPeriod.set(0); + } + + public long getTokensUsedInPeriod() { + return tokensUsedInPeriod.get(); + } + + public long checkLimit(long tokensNeededForRequest) { + return this.tokens - tokensUsedInPeriod.get() - tokensNeededForRequest; + } + public String getName() { return name; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java index d4ad25e009..16457a97db 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java @@ -64,6 +64,11 @@ public long checkLimit(AiApiUser user, long inputTokens, long outputTokens) { return 0; } + @Override + public long getRemainingResetTime() { + return 0; + } + private void createTablesIfNotExist() { try (var connection = getConnection(); var ps = connection.prepareStatement(CREATE_TABLE_SQL)) { ps.executeUpdate(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java deleted file mode 100644 index bf2a74a464..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/NoAiApiLimit.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.store; - -/** - * @description Store that does not limit the number of AI API calls (experimental). - */ -public class NoAiApiLimit extends AiApiLimit{ - - @Override - public long checkLimit(long tokensForNextRequest) { - return 1000; // Returns a value greater than 0 to indicate that the request can be processed. - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index 5b1697d849..d6af5d336d 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -1,14 +1,19 @@ package com.predic8.membrane.core.interceptor.ai.store; +import com.predic8.membrane.annot.MCAttribute; import com.predic8.membrane.annot.MCChildElement; import com.predic8.membrane.annot.MCElement; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.concurrent.GuardedBy; +import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import static java.time.Instant.now; + /** * @description Simple store for the LLM Gateway that stores limits in memory. Users and keys can * be configured in the configuration file. @@ -19,20 +24,26 @@ public class SimpleAiApiStore implements AiApiStore { private static final Logger log = LoggerFactory.getLogger(SimpleAiApiStore.class); private List users = new ArrayList<>(); - private AiApiLimit limit = new NoAiApiLimit(); private boolean logUsage = true; + private final Object lock = new Object(); + + @GuardedBy("lock") + private Instant nextReset; + + private long limitResetPeriod = 60; + @Override public void store(AiApiUser user, Usage usage) { if (logUsage) log.info("user: {} {}",user.getName(),usage.toString()); - limit.addTokens(usage.totalTokens()); + user.addTokensUsedInPeriod(usage); } @Override public Optional getUser(String token) { - return users.stream().filter(u -> u.getToken().equals(token)).findFirst(); + return users.stream().filter(u -> u.getApiKey().equals(token)).findFirst(); } @Override @@ -40,9 +51,26 @@ public long checkLimit(AiApiUser user, long inputTokens, long outputTokens) { if (user == null) return 0; // anonymous user gets no tokens - return limit.checkLimit(inputTokens + outputTokens); + synchronized (lock) { + var now = now(); + if (nextReset == null || now.isAfter(nextReset)) { + + nextReset = now.plusSeconds(limitResetPeriod); + log.info("Resetting AI API token usage limit."); + } + } + + return user.checkLimit(inputTokens + outputTokens); } + @Override + public long getRemainingResetTime() { + synchronized (lock) { + return nextReset == null ? 0 : (nextReset.toEpochMilli() - now().toEpochMilli()) / 1000; + } + } + + /** * List of users that can be used for authentication. * @param users User list @@ -56,18 +84,17 @@ public List getUsers() { return users; } - public AiApiLimit getLimit() { - return limit; + public long getLimitResetPeriod() { + return limitResetPeriod; } /** - * @description The limit of tokens that can be used for each user. - * @default 0 (no limit) - * @param limit + * @description The period in seconds after which the token limit is reset. + * @param limitResetPeriod in seconds, e.g. 3600 for 1 hour */ - @MCChildElement(allowForeign = true) - public void setLimit(AiApiLimit limit) { - this.limit = limit; + @MCAttribute + public void setLimitResetPeriod(long limitResetPeriod) { + this.limitResetPeriod = limitResetPeriod; } public boolean isLogUsage() { diff --git a/distribution/tutorials/ai/llm-gateway/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/20-Sharing-API-Keys.yaml new file mode 100644 index 0000000000..ffaa8ca42d --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/20-Sharing-API-Keys.yaml @@ -0,0 +1,55 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json +# +# Tutorial: Sharing LLM API Keys +# +# Replace <> with your OpenAI API key. +# +# Requests: +# +# 1. Hello AI +# curl -H "Content-Type: application/json" -H "Authorization: Bearer abc123" -d @simple.json http://localhost:2000/v1/responses +# Check: Successful response +# +# 2. Token Limit Exceeded +# Repeat the previous request until you receive: 429 Token Limit Exceeded +# User alice is blocked after the limit is exceeded. Bob should still be able to send requests. +# +# 3. Wrong Model +# curl -v -H "Content-Type: application/json" -H "Authorization: Bearer abc123" -d @wrong-model.json http://localhost:2000/v1/responses +# Check: Error response +# +# 4. Max. Input Tokens Exceeded +# curl -v -H "Content-Type: application/json" -H "Authorization: Bearer abc123" -d @max-input.json http://localhost:2000/v1/responses +# Check: Error response +# +# 5. Requested Max. Output Tokens Exceeded +# curl -v -H "Content-Type: application/json" -H "Authorization: Bearer abc123" -d @max-output.json http://localhost:2000/v1/responses +# Check: Field max_output_tokens in the response + +api: + port: 2000 + flow: + - llmGateway: + # Replace <> with your OpenAI API key + apiKey: <> + # Limits per request + maxInputTokens: 100 + maxOutputTokens: 200 + models: + - gpt-5.4 + - gpt-5-nano + - gpt-5-mini + openai: {} + simpleStore: + # User-facing API keys for the LLM Gateway + users: + - name: alice + apiKey: abc123 + tokens: 500 # Token limit for alice + - name: bob + apiKey: qwertz + tokens: 10000 + # Time in seconds after which the token limit is reset + limitResetPeriod: 60 + target: + url: https://api.openai.com/ From 0af0f350394abdfdea2a79a4793e65a9838096b9 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 15:34:09 +0200 Subject: [PATCH 23/43] refactor: enhance token usage tracking and improve inline documentation - Added `resetTokensUsedInPeriod` for user-specific token reset in `SimpleAiApiStore`. - Improved inline documentation for methods across `AiApiUser`, `AiApiStore`, and `LLMGatewayInterceptor`. - Updated parameter descriptions for clarity and consistency. --- .../interceptor/ai/LLMGatewayInterceptor.java | 8 ++++---- .../core/interceptor/ai/store/AiApiStore.java | 4 ++-- .../core/interceptor/ai/store/AiApiUser.java | 15 ++++++++++----- .../interceptor/ai/store/SimpleAiApiStore.java | 2 +- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index a841e2f922..3d648a5eb8 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -164,8 +164,8 @@ public String getApiKey() { } /** - * @param apiKey * @description API key for the LLM provider. Specify here the API key from OpenAI or Anthropic. + * @param apiKey LLM provider API key */ @MCAttribute public void setApiKey(String apiKey) { @@ -197,9 +197,9 @@ public int getMaxOutputTokens() { } /** - * @param maxOutputTokens * @description Maximum number of tokens the LLM should use to generate a response. This is just a hint that the gateway * sends to the LLM provider. The provider may use a different limit. + * @param maxOutputTokens Maximum number of tokens the LLM should use to generate a response. * @default 0 (unlimited) */ @MCAttribute @@ -212,9 +212,9 @@ public int getMaxInputTokens() { } /** - * @param maxInputTokens * @description Restricts token usage for the input. The size of the input is estimated by gateway based on the request size. * Actual token usage may be deviate from this value. + * @param maxInputTokens Maximum number of tokens that a request can use. */ @MCAttribute public void setMaxInputTokens(int maxInputTokens) { @@ -227,8 +227,8 @@ public List getModels() { /** * @desciption Restricts the models that can be used by the gateway. - * @param models * @default null (no restriction) + * @param models List of models that can be used by the gateway. */ @MCAttribute public void setModels(List models) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java index d78da8fcd3..73674eeef1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java @@ -19,8 +19,8 @@ default void init(Router router) { /** * Checks if the user has enough tokens to make the request. - * @param user - * @return + * @param user The user to check + * @return Estimated number of tokens that the user has left after this request */ long checkLimit(AiApiUser user, long inputTokens, long outputTokens); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java index 257ef86a28..5fd87b2f9b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -13,8 +13,12 @@ public class AiApiUser { private long tokens; - private AtomicLong tokensUsedInPeriod = new AtomicLong(); + private final AtomicLong tokensUsedInPeriod = new AtomicLong(); + /** + * Updates the store with the number of tokens used in this call + * @param usage The number of tokens used + */ public void addTokensUsedInPeriod(Usage usage) { tokensUsedInPeriod.addAndGet(usage.totalTokens()); } @@ -23,10 +27,11 @@ public void resetTokensUsedInPeriod() { tokensUsedInPeriod.set(0); } - public long getTokensUsedInPeriod() { - return tokensUsedInPeriod.get(); - } - + /** + * Checks if the user has enough tokens to make the request. + * @param tokensNeededForRequest The number of tokens that the user needs to make the request + * @return The estimated number of tokens that the user has left after this request + */ public long checkLimit(long tokensNeededForRequest) { return this.tokens - tokensUsedInPeriod.get() - tokensNeededForRequest; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index d6af5d336d..bd2976ce00 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -54,9 +54,9 @@ public long checkLimit(AiApiUser user, long inputTokens, long outputTokens) { synchronized (lock) { var now = now(); if (nextReset == null || now.isAfter(nextReset)) { - nextReset = now.plusSeconds(limitResetPeriod); log.info("Resetting AI API token usage limit."); + users.forEach(AiApiUser::resetTokensUsedInPeriod); } } From 329bb80d65d0c01e768e65b0a043944c1bb9ae09 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 16:39:08 +0200 Subject: [PATCH 24/43] feat: enhance error handling, synchronization, and token tracking - Added synchronized blocks to `SimpleAiApiStore` for thread-safe access to user data. - Introduced `invalidRequestError` to `LLMErrorCreator` and implemented it in `OpenAiErrorCreator`. - Allowed unlimited tokens for users with `MAX_VALUE` in `AiApiUser`. - Simplified logic in `LLMGatewayInterceptor` for token and model validation. - Updated tutorial JSON with test input for validation. --- .../interceptor/ai/LLMGatewayInterceptor.java | 46 ++++++++----------- .../ai/provider/LLMErrorCreator.java | 2 + .../provider/openai/OpenAiErrorCreator.java | 5 ++ .../core/interceptor/ai/store/AiApiUser.java | 6 ++- .../ai/store/SimpleAiApiStore.java | 13 ++++-- .../tutorials/ai/llm-gateway/wrong-model.json | 4 ++ 6 files changed, 45 insertions(+), 31 deletions(-) create mode 100644 distribution/tutorials/ai/llm-gateway/wrong-model.json diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 3d648a5eb8..e1cb492467 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -16,7 +16,6 @@ import java.util.List; -import static com.predic8.membrane.core.exceptions.ProblemDetails.user; import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; import static com.predic8.membrane.core.util.json.JsonUtil.setJsonBody; @@ -65,10 +64,7 @@ public Outcome handleRequest(Exchange exc) { try { aiReq = provider.getLLMRequest(exc); } catch (Exception e) { - user(router.getConfiguration().isProduction(), "AI Gateway") - .title("Invalid request") - .detail("Error parsing request: " + e.getMessage()) - .buildAndSetResponse(exc); + exc.setResponse(errorCreator.invalidRequestError("Error parsing request: " + e.getMessage())); return RETURN; } @@ -86,24 +82,22 @@ public Outcome handleRequest(Exchange exc) { } user = opt.get(); log.debug("User: {}", user); + exc.setProperty(MEMBRANE_AI_USER, user); } - long inputTokens = 0; - if (exc.getRequest().isPOSTRequest()) { - inputTokens = aiReq.estimateInputTokens(); - log.debug("Estimated input tokens: {}", inputTokens); - if (store != null) { - var remaining = store.checkLimit(user, inputTokens, maxOutputTokens); - log.debug("User {} has {} remaining tokens left", user, remaining); - if (remaining <= 0) { - log.info("Token limit exceeded. Remaining: {} input: {} maxOutput: {}",remaining, inputTokens, maxOutputTokens); - exc.setResponse(errorCreator.tokenLimitExceeded(inputTokens+maxOutputTokens, remaining, store.getRemainingResetTime())); - return RETURN; - } + long inputTokens = aiReq.estimateInputTokens(); + log.debug("Estimated input tokens: {}", inputTokens); + + // Check store limits + if (store != null) { + var remaining = store.checkLimit(user, inputTokens, maxOutputTokens); + log.debug("User {} has {} remaining tokens left", user, remaining); + if (remaining <= 0) { + log.info("Token limit exceeded. Remaining: {} input: {} maxOutput: {}", remaining, inputTokens, maxOutputTokens); + exc.setResponse(errorCreator.tokenLimitExceeded(inputTokens + maxOutputTokens, remaining, store.getRemainingResetTime())); + return RETURN; } } - exc.setProperty(MEMBRANE_AI_USER, user); - // If APIKey is specified, use that for the LLM. Overwrites keys from the client if (apiKey != null) { @@ -115,13 +109,13 @@ public Outcome handleRequest(Exchange exc) { var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); if (maxOutputTokens != 0 && (requestedMaxOutputTokens == -1 || requestedMaxOutputTokens > maxOutputTokens)) { - log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.",requestedMaxOutputTokens, maxOutputTokens); + log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, maxOutputTokens); aiReq.setMaxOutputTokens(maxOutputTokens); } if (maxInputTokens != 0) { if (inputTokens > maxInputTokens) { - log.info("Input tokens {} exceed the limit of {}.",inputTokens, maxInputTokens); + log.info("Input tokens {} exceed the limit of {}.", inputTokens, maxInputTokens); exc.setResponse(errorCreator.contextLengthExceeded(maxInputTokens, inputTokens)); return RETURN; } @@ -146,7 +140,7 @@ public Outcome handleResponse(Exchange exc) { provider.getLLMResponse(exc, res -> { var user = exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class); - if (log.isInfoEnabled() && user != null) { + if (log.isDebugEnabled() && user != null) { log.debug("Token usage of user {}: {}", user, res.getUsage()); } else { log.info("Token usage: {}", res.getUsage()); @@ -164,8 +158,8 @@ public String getApiKey() { } /** - * @description API key for the LLM provider. Specify here the API key from OpenAI or Anthropic. * @param apiKey LLM provider API key + * @description API key for the LLM provider. Specify here the API key from OpenAI or Anthropic. */ @MCAttribute public void setApiKey(String apiKey) { @@ -197,9 +191,9 @@ public int getMaxOutputTokens() { } /** + * @param maxOutputTokens Maximum number of tokens the LLM should use to generate a response. * @description Maximum number of tokens the LLM should use to generate a response. This is just a hint that the gateway * sends to the LLM provider. The provider may use a different limit. - * @param maxOutputTokens Maximum number of tokens the LLM should use to generate a response. * @default 0 (unlimited) */ @MCAttribute @@ -212,9 +206,9 @@ public int getMaxInputTokens() { } /** + * @param maxInputTokens Maximum number of tokens that a request can use. * @description Restricts token usage for the input. The size of the input is estimated by gateway based on the request size. * Actual token usage may be deviate from this value. - * @param maxInputTokens Maximum number of tokens that a request can use. */ @MCAttribute public void setMaxInputTokens(int maxInputTokens) { @@ -226,9 +220,9 @@ public List getModels() { } /** + * @param models List of models that can be used by the gateway. * @desciption Restricts the models that can be used by the gateway. * @default null (no restriction) - * @param models List of models that can be used by the gateway. */ @MCAttribute public void setModels(List models) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java index cd36b17738..28ddd23cbf 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java @@ -6,6 +6,8 @@ public interface LLMErrorCreator { + Response invalidRequestError(String message); + Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds); Response modelNotAllowed(String model, Collection allowedModels); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java index b919a73ef9..6f00a573c6 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java @@ -10,6 +10,11 @@ public class OpenAiErrorCreator extends AbstractLLMErrorCreator { + @Override + public Response invalidRequestError(String message) { + return Response.badRequest().json(envelope(message, "invalid_request_error", null, "bad_request")).build(); + } + public Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds) { return statusCode(429).json(envelope( "Token rate limit exceeded. Request requires %d tokens but only %d remain. Please wait %d seconds before retrying.".formatted(tokenRequired, tokenRemaining, tokenResetInSeconds), diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java index 5fd87b2f9b..9f8ff04e08 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -5,13 +5,15 @@ import java.util.concurrent.atomic.AtomicLong; +import static java.lang.Long.MAX_VALUE; + @MCElement(name = "users", component = false, id="ai-api-users") public class AiApiUser { private String name; private String apiKey; - private long tokens; + private long tokens = MAX_VALUE; private final AtomicLong tokensUsedInPeriod = new AtomicLong(); @@ -33,6 +35,8 @@ public void resetTokensUsedInPeriod() { * @return The estimated number of tokens that the user has left after this request */ public long checkLimit(long tokensNeededForRequest) { + if (tokens == MAX_VALUE) + return MAX_VALUE; return this.tokens - tokensUsedInPeriod.get() - tokensNeededForRequest; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index bd2976ce00..0a79756646 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -8,7 +8,7 @@ import javax.annotation.concurrent.GuardedBy; import java.time.Instant; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; @@ -23,7 +23,8 @@ public class SimpleAiApiStore implements AiApiStore { private static final Logger log = LoggerFactory.getLogger(SimpleAiApiStore.class); - private List users = new ArrayList<>(); + @GuardedBy("lock") + private List users = Collections.emptyList(); private boolean logUsage = true; @@ -43,7 +44,9 @@ public void store(AiApiUser user, Usage usage) { @Override public Optional getUser(String token) { - return users.stream().filter(u -> u.getApiKey().equals(token)).findFirst(); + synchronized (lock) { + return users.stream().filter(u -> u.getApiKey().equals(token)).findFirst(); + } } @Override @@ -77,7 +80,9 @@ public long getRemainingResetTime() { */ @MCChildElement(allowForeign = true,order = 10) public void setUsers(List users) { - this.users = users; + synchronized (lock) { + this.users = users; + } } public List getUsers() { diff --git a/distribution/tutorials/ai/llm-gateway/wrong-model.json b/distribution/tutorials/ai/llm-gateway/wrong-model.json new file mode 100644 index 0000000000..7a551564a2 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/wrong-model.json @@ -0,0 +1,4 @@ +{ + "model": "gpt-4", + "input": "Who are you?" +} \ No newline at end of file From 56d6e71b2218d845bb336f996d28d883b955f943 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 17:56:57 +0200 Subject: [PATCH 25/43] feat: extend LLM Gateway with Claude support and improved error handling - Added Claude-specific error handling with `ClaudeErrorCreator` and `ClaudeErrorResponse`. - Introduced `10-Basic-LLM-Gateway.yaml` tutorial for Claude integration. - Enhanced token usage tracking in `ClaudeLLMResponse`. - Updated examples and tutorials to support both OpenAI and Claude. --- .../interceptor/ai/LLMGatewayInterceptor.java | 2 +- .../ai/provider/LLMErrorCreator.java | 8 +- .../provider/claude/ClaudeErrorCreator.java | 52 +++++++++ .../provider/claude/ClaudeErrorResponse.java | 100 ++++++++++++++++++ .../ai/provider/claude/ClaudeLLMResponse.java | 24 ++++- .../ai/provider/claude/ClaudeProvider.java | 2 +- .../provider/openai/OpenAiErrorCreator.java | 4 +- .../core/interceptor/ai/store/AiApiUser.java | 4 +- .../ai/store/SimpleAiApiStore.java | 4 +- .../ai/llm-gateway/10-Basic-LLM-Gateway.yaml | 28 ----- .../claude/10-Basic-LLM-Gateway.yaml | 27 +++++ .../ai/llm-gateway/claude/max-input.json | 10 ++ .../ai/llm-gateway/claude/max-output.json | 10 ++ .../ai/llm-gateway/claude/simple.json | 10 ++ .../openai/10-Basic-LLM-Gateway.yaml | 26 +++++ .../{ => openai}/20-Sharing-API-Keys.yaml | 2 +- .../llm-gateway/{ => openai}/max-input.json | 0 .../llm-gateway/{ => openai}/max-output.json | 0 .../ai/llm-gateway/{ => openai}/membrane.cmd | 0 .../ai/llm-gateway/{ => openai}/membrane.sh | 0 .../ai/llm-gateway/{ => openai}/simple.json | 0 .../llm-gateway/{ => openai}/wrong-model.json | 0 22 files changed, 275 insertions(+), 38 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorResponse.java delete mode 100644 distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml create mode 100644 distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml create mode 100644 distribution/tutorials/ai/llm-gateway/claude/max-input.json create mode 100644 distribution/tutorials/ai/llm-gateway/claude/max-output.json create mode 100644 distribution/tutorials/ai/llm-gateway/claude/simple.json create mode 100644 distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml rename distribution/tutorials/ai/llm-gateway/{ => openai}/20-Sharing-API-Keys.yaml (97%) rename distribution/tutorials/ai/llm-gateway/{ => openai}/max-input.json (100%) rename distribution/tutorials/ai/llm-gateway/{ => openai}/max-output.json (100%) rename distribution/tutorials/ai/llm-gateway/{ => openai}/membrane.cmd (100%) rename distribution/tutorials/ai/llm-gateway/{ => openai}/membrane.sh (100%) rename distribution/tutorials/ai/llm-gateway/{ => openai}/simple.json (100%) rename distribution/tutorials/ai/llm-gateway/{ => openai}/wrong-model.json (100%) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index e1cb492467..a4ba53f7b8 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -116,7 +116,7 @@ public Outcome handleRequest(Exchange exc) { if (maxInputTokens != 0) { if (inputTokens > maxInputTokens) { log.info("Input tokens {} exceed the limit of {}.", inputTokens, maxInputTokens); - exc.setResponse(errorCreator.contextLengthExceeded(maxInputTokens, inputTokens)); + exc.setResponse(errorCreator.inputTokensExceeded(maxInputTokens, inputTokens)); return RETURN; } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java index 28ddd23cbf..f81bda7445 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java @@ -14,5 +14,11 @@ public interface LLMErrorCreator { Response authenticationFailed(); - Response contextLengthExceeded(long maxTokens, long estimatedTokens); + /** + * + * @param maxTokens + * @param estimatedTokens + * @return + */ + Response inputTokensExceeded(long maxTokens, long estimatedTokens); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java new file mode 100644 index 0000000000..9cf25f7ae5 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java @@ -0,0 +1,52 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.predic8.membrane.core.http.Response; +import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.ai.provider.claude.ClaudeErrorResponse.ClaudeError; + +import java.util.Collection; +import java.util.UUID; + +import static com.predic8.membrane.core.http.Response.badRequest; + +public class ClaudeErrorCreator implements LLMErrorCreator { + + // Claude error types + private static final String RATE_LIMIT_ERROR = "rate_limit_error"; + + @Override + public Response invalidRequestError(String message) { + return null; + } + + @Override + public Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds) { + return null; + } + + @Override + public Response modelNotAllowed(String model, Collection allowedModels) { + return null; + } + + @Override + public Response authenticationFailed() { + return null; + } + + @Override + public Response inputTokensExceeded(long maxTokens, long estimatedTokens) { + var json = ClaudeErrorResponse.builder().error( + ClaudeError.builder().type(RATE_LIMIT_ERROR) + .message(""" + prompt is too long: + %d tokens > %d maximum + """.formatted(estimatedTokens, maxTokens).trim()) + ).requestId("membrane_" + UUID.randomUUID()) + .toJson(); + + return badRequest() + .json(json) + .build(); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorResponse.java new file mode 100644 index 0000000000..2bdce96c2e --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorResponse.java @@ -0,0 +1,100 @@ +package com.predic8.membrane.core.interceptor.ai.provider.claude; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class ClaudeErrorResponse { + + private static final ObjectMapper om = new ObjectMapper(); + + private String type = "error"; + private ClaudeError error; + private String request_id; + + public static ClaudeErrorResponse builder() { + return new ClaudeErrorResponse(); + } + + public String getType() { + return type; + } + + public ClaudeErrorResponse type(String type) { + this.type = type; + return this; + } + + public ClaudeError getError() { + return error; + } + + public ClaudeErrorResponse error(ClaudeError error) { + this.error = error; + return this; + } + + public String getRequest_id() { + return request_id; + } + + public ClaudeErrorResponse requestId(String requestId) { + this.request_id = requestId; + return this; + } + + public String toJson() { + try { + return om.writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize ClaudeErrorResponse", e); + } + } + + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ClaudeError { + + private String type; + private String message; + + public static ClaudeError builder() { + return new ClaudeError(); + } + + public String getType() { + return type; + } + + public ClaudeError type(String type) { + this.type = type; + return this; + } + + public String getMessage() { + return message; + } + + public ClaudeError message(String message) { + this.message = message; + return this; + } + + @Override + public String toString() { + return "ClaudeError{" + + "type='" + type + '\'' + + ", message='" + message + '\'' + + '}'; + } + } + + @Override + public String toString() { + return "ClaudeErrorResponse{" + + "type='" + type + '\'' + + ", error=" + error + + ", request_id='" + request_id + '\'' + + '}'; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java index dbdbbe3de4..f487d43b0d 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java @@ -1,5 +1,6 @@ package com.predic8.membrane.core.interceptor.ai.provider.claude; +import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; @@ -57,9 +58,30 @@ public void process(SSEEvent event) { } } + Usage extractUsage() { + + var usage = json.path("usage"); + + var inputTokens = getInputTokens(usage); + var outputTokens = getOutputTokens(usage); + var totalTokens = inputTokens + outputTokens; + return new Usage(inputTokens, outputTokens, totalTokens); + + } + + protected static int getOutputTokens(JsonNode usage) { + return usage.path("output_tokens").asInt(0); + } + + protected static int getInputTokens(JsonNode usage) { + return usage.path("input_tokens").asInt(0); + } + @Override public Usage getUsage() { - return usage; + if (usage != null) + return usage; + return usage = extractUsage(); } } \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java index 08ea66e425..99ba4820e7 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java @@ -28,6 +28,6 @@ public LLMResponse getLLMResponse(Exchange request, Consumer postPr @Override public LLMErrorCreator getErrorCreator() { - return null; + return new ClaudeErrorCreator(); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java index 6f00a573c6..7f51494ad5 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java @@ -12,7 +12,7 @@ public class OpenAiErrorCreator extends AbstractLLMErrorCreator { @Override public Response invalidRequestError(String message) { - return Response.badRequest().json(envelope(message, "invalid_request_error", null, "bad_request")).build(); + return badRequest().json(envelope(message, "invalid_request_error", null, "bad_request")).build(); } public Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds) { @@ -40,7 +40,7 @@ public Response authenticationFailed() { "invalid_authentication")).build(); } - public Response contextLengthExceeded(long maxTokens, long estimatedTokens) { + public Response inputTokensExceeded(long maxTokens, long estimatedTokens) { return badRequest().json(envelope( """ This model's maximum context length is %d tokens. diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java index 9f8ff04e08..cd3ab76b4b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java @@ -13,7 +13,7 @@ public class AiApiUser { private String name; private String apiKey; - private long tokens = MAX_VALUE; + private long tokens = 0; private final AtomicLong tokensUsedInPeriod = new AtomicLong(); @@ -35,7 +35,7 @@ public void resetTokensUsedInPeriod() { * @return The estimated number of tokens that the user has left after this request */ public long checkLimit(long tokensNeededForRequest) { - if (tokens == MAX_VALUE) + if (tokens == 0) return MAX_VALUE; return this.tokens - tokensUsedInPeriod.get() - tokensNeededForRequest; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index 0a79756646..0793a2c75d 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -86,7 +86,9 @@ public void setUsers(List users) { } public List getUsers() { - return users; + synchronized (lock) { + return users; + } } public long getLimitResetPeriod() { diff --git a/distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml deleted file mode 100644 index edf4e5b368..0000000000 --- a/distribution/tutorials/ai/llm-gateway/10-Basic-LLM-Gateway.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json -# -# Tutorial: Basic LLM Gateway -# -# Replace <> with your OpenAI API key. -# -# 1. Hello World -# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @simple.json http://localhost:2000/v1/responses -# -# 2. Exceed the input token limit -# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-input.json http://localhost:2000/v1/responses -# => Returns an error because the request exceeds maxInputTokens. -# -# 3. Exceed the output token limit -# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-output.json http://localhost:2000/v1/responses -# => Check the max_output_tokens field in the response. - -api: - port: 2000 - flow: - - llmGateway: - openai: {} - maxInputTokens: 100 - maxOutputTokens: 200 - - request: - - log: {} - target: - url: https://api.openai.com diff --git a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml new file mode 100644 index 0000000000..5a9704c845 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml @@ -0,0 +1,27 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json +# +# Tutorial: Basic LLM Gateway (Antropic Claude) +# +# Replace <> with your OpenAI API key. +# +# 1. Hello World +# curl -v -H "Content-Type: application/json" -H "x-api-key: <>" -H "anthropic-version: 2023-06-01" -d @simple.json http://localhost:2000/v1/messages +# Check the response and the Membrane logs. +# +# 2. Exceed the input token limit +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-input.json http://localhost:2000/v1/responses +# Returns an error because the request exceeds maxInputTokens. +# +# 3. Exceed the output token limit +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-output.json http://localhost:2000/v1/responses +# Check the Membrane log for limiting max tokens to 200 + +api: + port: 2000 + flow: + - llmGateway: + claude: {} + maxInputTokens: 100 + maxOutputTokens: 200 + target: + url: https://api.anthropic.com \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/claude/max-input.json b/distribution/tutorials/ai/llm-gateway/claude/max-input.json new file mode 100644 index 0000000000..a51d79d50e --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/max-input.json @@ -0,0 +1,10 @@ +{ + "model": "claude-sonnet-4-0", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Who are you, where do you get your information from, how do you answer questions, why were you created, what kinds of problems can you solve, where do you go when you search for information, how do you decide what is important, what do you know about programming, science, history, languages, and technology, how do you explain difficult concepts to people, why do people use AI assistants, what happens when you do not know an answer, and why should someone trust the answers you provide?" + } + ] +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/claude/max-output.json b/distribution/tutorials/ai/llm-gateway/claude/max-output.json new file mode 100644 index 0000000000..0b1e1b3b21 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/max-output.json @@ -0,0 +1,10 @@ +{ + "model": "claude-sonnet-4-0", + "max_tokens": 500, + "messages": [ + { + "role": "user", + "content": "What is your name?" + } + ] +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/claude/simple.json b/distribution/tutorials/ai/llm-gateway/claude/simple.json new file mode 100644 index 0000000000..bd6b974408 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/simple.json @@ -0,0 +1,10 @@ +{ + "model": "claude-sonnet-4-0", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Who are you?" + } + ] +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml new file mode 100644 index 0000000000..07ce7c4aff --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml @@ -0,0 +1,26 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json +# +# Tutorial: Basic LLM Gateway (OpenAI) +# +# Replace <> with your OpenAI API key. +# +# 1. Hello World +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @simple.json http://localhost:2000/v1/responses +# +# 2. Exceed the input token limit +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-input.json http://localhost:2000/v1/responses +# Returns an error because the request exceeds maxInputTokens. +# +# 3. Exceed the output token limit +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-output.json http://localhost:2000/v1/responses +# Check the max_output_tokens field in the response and the Membrane log + +api: + port: 2000 + flow: + - llmGateway: + openai: {} + maxInputTokens: 100 + maxOutputTokens: 200 + target: + url: https://api.openai.com diff --git a/distribution/tutorials/ai/llm-gateway/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml similarity index 97% rename from distribution/tutorials/ai/llm-gateway/20-Sharing-API-Keys.yaml rename to distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml index ffaa8ca42d..16a0327022 100644 --- a/distribution/tutorials/ai/llm-gateway/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml @@ -1,6 +1,6 @@ # yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json # -# Tutorial: Sharing LLM API Keys +# Tutorial: Sharing LLM API Keys (OpenAI) # # Replace <> with your OpenAI API key. # diff --git a/distribution/tutorials/ai/llm-gateway/max-input.json b/distribution/tutorials/ai/llm-gateway/openai/max-input.json similarity index 100% rename from distribution/tutorials/ai/llm-gateway/max-input.json rename to distribution/tutorials/ai/llm-gateway/openai/max-input.json diff --git a/distribution/tutorials/ai/llm-gateway/max-output.json b/distribution/tutorials/ai/llm-gateway/openai/max-output.json similarity index 100% rename from distribution/tutorials/ai/llm-gateway/max-output.json rename to distribution/tutorials/ai/llm-gateway/openai/max-output.json diff --git a/distribution/tutorials/ai/llm-gateway/membrane.cmd b/distribution/tutorials/ai/llm-gateway/openai/membrane.cmd similarity index 100% rename from distribution/tutorials/ai/llm-gateway/membrane.cmd rename to distribution/tutorials/ai/llm-gateway/openai/membrane.cmd diff --git a/distribution/tutorials/ai/llm-gateway/membrane.sh b/distribution/tutorials/ai/llm-gateway/openai/membrane.sh similarity index 100% rename from distribution/tutorials/ai/llm-gateway/membrane.sh rename to distribution/tutorials/ai/llm-gateway/openai/membrane.sh diff --git a/distribution/tutorials/ai/llm-gateway/simple.json b/distribution/tutorials/ai/llm-gateway/openai/simple.json similarity index 100% rename from distribution/tutorials/ai/llm-gateway/simple.json rename to distribution/tutorials/ai/llm-gateway/openai/simple.json diff --git a/distribution/tutorials/ai/llm-gateway/wrong-model.json b/distribution/tutorials/ai/llm-gateway/openai/wrong-model.json similarity index 100% rename from distribution/tutorials/ai/llm-gateway/wrong-model.json rename to distribution/tutorials/ai/llm-gateway/openai/wrong-model.json From e0157c8b6f2f52cb213543c0fbfae39d99e25a72 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 20 May 2026 17:59:08 +0200 Subject: [PATCH 26/43] refactor: improve parameter documentation in `LLMErrorCreator` --- .../core/interceptor/ai/provider/LLMErrorCreator.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java index f81bda7445..ee06f9b7c3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java @@ -16,9 +16,9 @@ public interface LLMErrorCreator { /** * - * @param maxTokens - * @param estimatedTokens - * @return + * @param maxTokens as configured + * @param estimatedTokens estimated number of input tokens + * @return Response error response */ Response inputTokensExceeded(long maxTokens, long estimatedTokens); } From 11491c6c337b4047db1d816b2a265a85e61724d7 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 21 May 2026 09:12:19 +0200 Subject: [PATCH 27/43] chore: add Apache 2.0 license headers to core files --- .../interceptor/mcp/ExchangeToolSupport.java | 29 +++++++++-------- .../core/interceptor/mcp/ExchangeUtils.java | 14 +++++++++ .../core/interceptor/mcp/MCPUtil.java | 14 +++++++++ .../interceptor/mcp/McpPayloadSanitizer.java | 14 +++++++++ .../interceptor/mcp/McpSchemaBuilder.java | 14 +++++++++ .../interceptor/mcp/McpSessionContext.java | 14 +++++++++ .../interceptor/mcp/McpSessionManager.java | 14 +++++++++ .../interceptor/mcp/McpToolDefinition.java | 14 +++++++++ .../core/interceptor/mcp/McpToolHandler.java | 14 +++++++++ .../core/interceptor/mcp/McpToolRegistry.java | 14 +++++++++ .../interceptor/mcp/MembraneMCPServer.java | 31 ++++++++++--------- .../membrane/core/jsonrpc/JSONRPCRequest.java | 15 ++++++++- .../core/jsonrpc/JSONRPCResponse.java | 16 +++++++++- .../membrane/core/jsonrpc/JSONRPCUtil.java | 14 +++++++++ .../membrane/core/mcp/MCPInitialize.java | 14 +++++++++ .../core/mcp/MCPInitializeResponse.java | 14 +++++++++ .../membrane/core/mcp/MCPInitialized.java | 14 +++++++++ .../membrane/core/mcp/MCPNotification.java | 14 +++++++++ .../predic8/membrane/core/mcp/MCPPing.java | 14 +++++++++ .../predic8/membrane/core/mcp/MCPRequest.java | 16 ++++++++-- .../membrane/core/mcp/MCPResponse.java | 14 +++++++++ .../membrane/core/mcp/MCPToolsCall.java | 14 +++++++++ .../core/mcp/MCPToolsCallResponse.java | 14 +++++++++ .../membrane/core/mcp/MCPToolsList.java | 14 +++++++++ .../core/mcp/MCPToolsListResponse.java | 14 +++++++++ .../mcp/MembraneMCPServerTest.java | 14 +++++++++ .../core/jsonrpc/JSONRPCRequestTest.java | 14 +++++++++ .../core/jsonrpc/JSONRPCResponseTest.java | 14 +++++++++ .../membrane/core/mcp/MCPInitializeTest.java | 14 +++++++++ 29 files changed, 411 insertions(+), 32 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java index 319636fe29..6bfb14c5c8 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java @@ -1,29 +1,32 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.predic8.membrane.core.exchange.AbstractExchange; -import com.predic8.membrane.core.interceptor.mcp.MCPUtil.InvalidToolArgumentsException; +import com.predic8.membrane.core.interceptor.mcp.MCPUtil.*; import com.predic8.membrane.core.mcp.MCPToolsCall; import com.predic8.membrane.core.mcp.MCPToolsCallResponse; import org.jetbrains.annotations.Nullable; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.UUID; +import java.util.*; import static com.predic8.membrane.core.interceptor.mcp.ExchangeUtils.matchesExchangeFilter; -import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalBooleanArgument; -import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalIntArgument; -import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalSizeArgument; -import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getOptionalStringArgument; -import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.getRequiredLongArgument; -import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.rejectUnexpectedArguments; import static com.predic8.membrane.core.interceptor.mcp.McpSchemaBuilder.integer; import static com.predic8.membrane.core.interceptor.mcp.McpSchemaBuilder.string; import static java.lang.Integer.MAX_VALUE; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeUtils.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeUtils.java index 274e150d3d..5e3c02c846 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeUtils.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeUtils.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.predic8.membrane.core.exchange.AbstractExchange; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPUtil.java index 148965c53e..b3c5b866df 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPUtil.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.predic8.membrane.core.exchange.AbstractExchange; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpPayloadSanitizer.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpPayloadSanitizer.java index addf83e551..dd9b291c9a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpPayloadSanitizer.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpPayloadSanitizer.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.predic8.membrane.core.http.Header; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSchemaBuilder.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSchemaBuilder.java index 14539ca1cc..d03b7d1b05 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSchemaBuilder.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSchemaBuilder.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import java.util.Collections; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionContext.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionContext.java index 615e22d543..0bff5dbcab 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionContext.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionContext.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.predic8.membrane.core.mcp.MCPInitialize.ClientInfo; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionManager.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionManager.java index e67394c480..d607e51b7f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionManager.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpSessionManager.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.predic8.membrane.core.mcp.MCPInitialize; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolDefinition.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolDefinition.java index 49d14cbb4f..f07b51b8fd 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolDefinition.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolDefinition.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.predic8.membrane.core.mcp.MCPToolsListResponse; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolHandler.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolHandler.java index b8bcf4acf3..58910c945c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolHandler.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolHandler.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.predic8.membrane.core.exchange.Exchange; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolRegistry.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolRegistry.java index 1dde3ef18a..0480bc715e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolRegistry.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/McpToolRegistry.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import java.util.Collection; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServer.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServer.java index a22da0473f..4fc1bab70e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServer.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServer.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.fasterxml.jackson.core.JsonProcessingException; @@ -9,14 +23,7 @@ import com.predic8.membrane.core.interceptor.mcp.MCPUtil.InvalidToolArgumentsException; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; import com.predic8.membrane.core.jsonrpc.JSONRPCResponse; -import com.predic8.membrane.core.mcp.MCPInitialize; -import com.predic8.membrane.core.mcp.MCPInitializeResponse; -import com.predic8.membrane.core.mcp.MCPInitialized; -import com.predic8.membrane.core.mcp.MCPPing; -import com.predic8.membrane.core.mcp.MCPToolsCall; -import com.predic8.membrane.core.mcp.MCPToolsCallResponse; -import com.predic8.membrane.core.mcp.MCPToolsList; -import com.predic8.membrane.core.mcp.MCPToolsListResponse; +import com.predic8.membrane.core.mcp.*; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,13 +43,7 @@ import static com.predic8.membrane.core.interceptor.mcp.McpSessionContext.McpSessionState.INITIALIZED; import static com.predic8.membrane.core.interceptor.mcp.McpSessionContext.McpSessionState.READY; import static com.predic8.membrane.core.jsonrpc.JSONRPCRequest.parse; -import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.ERR_INTERNAL_ERROR; -import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.ERR_INVALID_PARAMS; -import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.ERR_INVALID_REQUEST; -import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.ERR_METHOD_NOT_FOUND; -import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.ERR_PARSE_ERROR; -import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.error; -import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.success; +import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.*; /** * @description MCP Server for Membrane. It allows querying Membrane's internal state and operation from an LLM diff --git a/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequest.java b/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequest.java index 65f3b3b86b..f9b802c163 100644 --- a/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequest.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.jsonrpc; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -15,7 +29,6 @@ import java.io.IOException; import java.io.InputStream; -import java.io.OutputStream; import java.util.List; import java.util.Map; import java.util.Objects; diff --git a/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponse.java b/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponse.java index ebe1a2513e..02d95507fb 100644 --- a/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponse.java @@ -1,10 +1,24 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.jsonrpc; -import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializerProvider; diff --git a/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCUtil.java b/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCUtil.java index 158fa9e087..922c391ef1 100644 --- a/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/jsonrpc/JSONRPCUtil.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.jsonrpc; import com.fasterxml.jackson.databind.JsonNode; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialize.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialize.java index a8d94ca230..9aa79a1ca1 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialize.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialize.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitializeResponse.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitializeResponse.java index c3227bae7a..a49610a9ef 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitializeResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitializeResponse.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialized.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialized.java index 27d63ee716..f39698807a 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialized.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPInitialized.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPNotification.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPNotification.java index eeefa8e421..3e106d3aa3 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPNotification.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPNotification.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPPing.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPPing.java index b7aa3c8ce1..3746a156ba 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPPing.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPPing.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPRequest.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPRequest.java index 1d47fcf638..059766bc63 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPRequest.java @@ -1,9 +1,21 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; -import java.util.Objects; - import static java.util.Objects.requireNonNull; /** diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPResponse.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPResponse.java index f9a27bd1d6..faaba2786f 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPResponse.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCResponse; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCall.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCall.java index 1724b5022e..51055b2ed8 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCall.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCall.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCallResponse.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCallResponse.java index 026b032375..e221eced60 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCallResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsCallResponse.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.fasterxml.jackson.annotation.*; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsList.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsList.java index c0beba054c..dbada1138f 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsList.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsList.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; diff --git a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsListResponse.java b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsListResponse.java index ad78539949..48d72fbef9 100644 --- a/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsListResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/mcp/MCPToolsListResponse.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServerTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServerTest.java index 33fa0184f4..b52986e458 100644 --- a/core/src/test/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServerTest.java +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/mcp/MembraneMCPServerTest.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.interceptor.mcp; import com.fasterxml.jackson.databind.JsonNode; diff --git a/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequestTest.java b/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequestTest.java index 935ae66c12..f5efd94d6e 100644 --- a/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequestTest.java +++ b/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCRequestTest.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.jsonrpc; import org.junit.jupiter.api.Test; diff --git a/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponseTest.java b/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponseTest.java index 97eed97b58..7c0a85e98c 100644 --- a/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponseTest.java +++ b/core/src/test/java/com/predic8/membrane/core/jsonrpc/JSONRPCResponseTest.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.jsonrpc; import org.junit.jupiter.api.Test; diff --git a/core/src/test/java/com/predic8/membrane/core/mcp/MCPInitializeTest.java b/core/src/test/java/com/predic8/membrane/core/mcp/MCPInitializeTest.java index 02d6dadce4..0abbb75f66 100644 --- a/core/src/test/java/com/predic8/membrane/core/mcp/MCPInitializeTest.java +++ b/core/src/test/java/com/predic8/membrane/core/mcp/MCPInitializeTest.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.mcp; import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; From 6485949dc5b29a88d312c0b08acd13111233c16a Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 21 May 2026 11:30:33 +0200 Subject: [PATCH 28/43] feat: add Google Gemini and enhance Claude tutorials with API key sharing and token limit examples - Added `10-Basic-LLM-Gateway.yaml` and `20-Sharing-API-Keys.yaml` tutorials for Google Gemini. - Enhanced Claude tutorials with improved key handling and token limit examples. - Introduced `GoogleErrorCreator` for detailed error handling in Google LLM Gateway. - Updated `LLMGatewayInterceptor` and token tracking logic to reflect effective max token handling. - Modified existing OpenAI and Claude examples for consistency and clarity. --- .../interceptor/ai/LLMGatewayInterceptor.java | 7 +- .../provider/claude/ClaudeErrorCreator.java | 67 +++++++++---- .../provider/google/GoogleErrorCreator.java | 98 +++++++++++++++++++ .../ai/provider/google/GoogleLLMRequest.java | 3 + .../ai/provider/google/GoogleLLMResponse.java | 4 +- .../ai/provider/google/GoogleProvider.java | 2 +- .../claude/10-Basic-LLM-Gateway.yaml | 4 +- .../claude/20-Sharing-API-Keys.yaml | 56 +++++++++++ .../ai/llm-gateway/claude/max-output.json | 2 +- .../ai/llm-gateway/claude/membrane.cmd | 24 +++++ .../ai/llm-gateway/claude/membrane.sh | 21 ++++ .../ai/llm-gateway/claude/wrong-model.json | 10 ++ .../google/10-Basic-LLM-Gateway.yaml | 27 +++++ .../google/20-Sharing-API-Keys.yaml | 56 +++++++++++ .../ai/llm-gateway/google/max-input.json | 11 +++ .../ai/llm-gateway/google/max-output.json | 14 +++ .../ai/llm-gateway/google/membrane.cmd | 24 +++++ .../ai/llm-gateway/google/membrane.sh | 21 ++++ .../ai/llm-gateway/google/simple.json | 11 +++ .../openai/20-Sharing-API-Keys.yaml | 7 +- .../ai/llm-gateway/openai/max-output.json | 2 +- 21 files changed, 441 insertions(+), 30 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java create mode 100644 distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml create mode 100644 distribution/tutorials/ai/llm-gateway/claude/membrane.cmd create mode 100755 distribution/tutorials/ai/llm-gateway/claude/membrane.sh create mode 100644 distribution/tutorials/ai/llm-gateway/claude/wrong-model.json create mode 100644 distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml create mode 100644 distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml create mode 100644 distribution/tutorials/ai/llm-gateway/google/max-input.json create mode 100644 distribution/tutorials/ai/llm-gateway/google/max-output.json create mode 100644 distribution/tutorials/ai/llm-gateway/google/membrane.cmd create mode 100755 distribution/tutorials/ai/llm-gateway/google/membrane.sh create mode 100644 distribution/tutorials/ai/llm-gateway/google/simple.json diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index a4ba53f7b8..41c536c39b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -90,11 +90,12 @@ public Outcome handleRequest(Exchange exc) { // Check store limits if (store != null) { - var remaining = store.checkLimit(user, inputTokens, maxOutputTokens); + var effectiveMaxTokens = Math.min(aiReq.getRequestedMaxOutputTokens(), maxOutputTokens); + var remaining = store.checkLimit(user, inputTokens, effectiveMaxTokens); log.debug("User {} has {} remaining tokens left", user, remaining); if (remaining <= 0) { - log.info("Token limit exceeded. Remaining: {} input: {} maxOutput: {}", remaining, inputTokens, maxOutputTokens); - exc.setResponse(errorCreator.tokenLimitExceeded(inputTokens + maxOutputTokens, remaining, store.getRemainingResetTime())); + log.info("Token limit exceeded. Remaining: {} input: {} maxOutput: {}", remaining, inputTokens, effectiveMaxTokens); + exc.setResponse(errorCreator.tokenLimitExceeded(inputTokens + effectiveMaxTokens, remaining, store.getRemainingResetTime())); return RETURN; } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java index 9cf25f7ae5..506b6d2083 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java @@ -7,46 +7,79 @@ import java.util.Collection; import java.util.UUID; -import static com.predic8.membrane.core.http.Response.badRequest; +import static com.predic8.membrane.core.http.Header.WWW_AUTHENTICATE; +import static com.predic8.membrane.core.http.Response.*; public class ClaudeErrorCreator implements LLMErrorCreator { - // Claude error types + private static final String INVALID_REQUEST_ERROR = "invalid_request_error"; + private static final String AUTHENTICATION_ERROR = "authentication_error"; private static final String RATE_LIMIT_ERROR = "rate_limit_error"; @Override public Response invalidRequestError(String message) { - return null; + return badRequest() + .json(error(INVALID_REQUEST_ERROR, message)) + .build(); } @Override public Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds) { - return null; + long visibleRemaining = Math.max(0, tokenRemaining); + + return statusCode(429) + .json(error( + RATE_LIMIT_ERROR, + """ + Token rate limit exceeded. + Request requires %d tokens but only %d remain. + Retry after %d seconds. + """.formatted(tokenRequired, visibleRemaining, tokenResetInSeconds).trim() + )) + .build(); } @Override public Response modelNotAllowed(String model, Collection allowedModels) { - return null; + return badRequest() + .json(error( + INVALID_REQUEST_ERROR, + "Model '%s' is not allowed. Allowed models: %s." + .formatted(model, String.join(", ", allowedModels)) + )) + .build(); } @Override public Response authenticationFailed() { - return null; + return unauthorized() + .header(WWW_AUTHENTICATE, "Bearer") + .json(error(AUTHENTICATION_ERROR, "Invalid bearer token")) + .build(); } @Override public Response inputTokensExceeded(long maxTokens, long estimatedTokens) { - var json = ClaudeErrorResponse.builder().error( - ClaudeError.builder().type(RATE_LIMIT_ERROR) - .message(""" - prompt is too long: - %d tokens > %d maximum - """.formatted(estimatedTokens, maxTokens).trim()) - ).requestId("membrane_" + UUID.randomUUID()) - .toJson(); - return badRequest() - .json(json) + .json(error( + INVALID_REQUEST_ERROR, + """ + prompt is too long: + %d tokens > %d maximum + """.formatted(estimatedTokens, maxTokens).trim() + )) .build(); } -} + + private String error(String type, String message) { + return ClaudeErrorResponse.builder() + .type("error") + .error( + ClaudeError.builder() + .type(type) + .message(message) + ) + .requestId("membrane_" + UUID.randomUUID()) + .toJson(); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java new file mode 100644 index 0000000000..281a314594 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java @@ -0,0 +1,98 @@ +package com.predic8.membrane.core.interceptor.ai.provider.google; + +import com.predic8.membrane.core.http.Response; +import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMErrorCreator; + +import java.util.Collection; + +import static com.predic8.membrane.core.http.Header.WWW_AUTHENTICATE; +import static com.predic8.membrane.core.http.Response.*; + +public class GoogleErrorCreator extends AbstractLLMErrorCreator { + + @Override + public Response invalidRequestError(String message) { + return badRequest().json( + envelope(400, message, "INVALID_ARGUMENT") + ).build(); + } + + public Response tokenLimitExceeded(long tokenRequired, + long tokenRemaining, + long tokenResetInSeconds) { + + return statusCode(429).json( + envelope( + 429, + """ + Token rate limit exceeded. + Request requires %d tokens but only %d remain. + Retry after %d seconds. + """ + .formatted(tokenRequired, tokenRemaining, tokenResetInSeconds) + .trim(), + "RESOURCE_EXHAUSTED" + ) + ).build(); + } + + public Response modelNotAllowed(String model, + Collection allowedModels) { + + return badRequest().json( + envelope( + 400, + "Model '%s' is not allowed. Allowed models: %s." + .formatted(model, String.join(", ", allowedModels)), + "INVALID_ARGUMENT" + ) + ).build(); + } + + public Response authenticationFailed() { + return unauthorized() + .header(WWW_AUTHENTICATE, "Bearer") + .json( + envelope( + 401, + "Invalid API key.", + "UNAUTHENTICATED" + ) + ).build(); + } + + public Response inputTokensExceeded(long maxTokens, + long estimatedTokens) { + + return badRequest().json( + envelope( + 400, + """ + The input token count (%d) exceeds the maximum allowed (%d). + """ + .formatted(estimatedTokens, maxTokens) + .trim(), + "INVALID_ARGUMENT" + ) + ).build(); + } + + private String envelope(int code, + String message, + String status) { + + return createJson(new ErrorEnvelope( + new ErrorBody(code, message, status) + )); + } + + private record ErrorEnvelope(ErrorBody error) { + } + + private record ErrorBody( + int code, + String message, + String status + ) { + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java index aff431819d..da0b174465 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java @@ -7,6 +7,9 @@ public class GoogleLLMRequest extends AbstractLLMRequest { + /** + * x-goog-api-key is correct it is not google + */ public static final String X_GOOG_API_KEY = "x-goog-api-key"; public GoogleLLMRequest(Exchange exchange) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java index 5b0adb2144..db04ae85df 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java @@ -20,7 +20,9 @@ public Usage getUsage() { var usage = json.path("usageMetadata"); int inputTokens = usage.path("promptTokenCount").asInt(0); - int outputTokens = usage.path("candidatesTokenCount").asInt(0); + int thoughtsTokens = usage.path("thoughtsTokenCount").asInt(0); + int candidatesTokenCount = usage.path("candidatesTokenCount").asInt(0); + int outputTokens = thoughtsTokens + candidatesTokenCount; int totalTokens = usage.path("totalTokenCount").asInt(inputTokens + outputTokens); return new Usage( diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java index 8c3eab8b41..4ee8860f89 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java @@ -28,6 +28,6 @@ public LLMResponse getLLMResponse(Exchange request, Consumer postPr @Override public LLMErrorCreator getErrorCreator() { - return null; + return new GoogleErrorCreator(); } } diff --git a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml index 5a9704c845..8931dfd28a 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml +++ b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml @@ -9,11 +9,11 @@ # Check the response and the Membrane logs. # # 2. Exceed the input token limit -# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-input.json http://localhost:2000/v1/responses +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-input.json http://localhost:2000/v1/messages # Returns an error because the request exceeds maxInputTokens. # # 3. Exceed the output token limit -# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-output.json http://localhost:2000/v1/responses +# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-output.json http://localhost:2000/v1/messages # Check the Membrane log for limiting max tokens to 200 api: diff --git a/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml new file mode 100644 index 0000000000..6cb3f0c0b3 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml @@ -0,0 +1,56 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json +# +# Tutorial: Sharing LLM API Keys (Claude) +# +# Replace <> with your Claude API key. +# +# Requests: +# +# 1. Hello AI +# curl -v -H "Content-Type: application/json" -H "x-api-key: abc123" -H "anthropic-version: 2023-06-01" -d @simple.json http://localhost:2000/v1/messages +# Check: Successful response +# +# 2. Token Limit Exceeded +# Repeat the previous request until you receive: 429 Token Limit Exceeded +# User alice is blocked after the limit is exceeded. Bob should still be able to send requests. +# +# 3. Wrong Model +# curl -v -H "Content-Type: application/json" -H "x-api-key: abc123" -H "anthropic-version: 2023-06-01" -d @wrong-model.json http://localhost:2000/v1/messages +# Check: Error response +# +# 4. Max. Input Tokens Exceeded +# curl -v -H "Content-Type: application/json" -H "x-api-key: abc123" -H "anthropic-version: 2023-06-01" -d @max-input.json http://localhost:2000/v1/messages +# Check: Error response +# +# 5. Requested Max. Output Tokens Exceeded +# curl -v -H "Content-Type: application/json" -H "x-api-key: abc123" -H "anthropic-version: 2023-06-01" -d @max-output.json http://localhost:2000/v1/messages +# Check Membrane log: totalTokens should not exceed 200 even it was requested in max-output.json + +api: + port: 2000 + flow: + - llmGateway: + claude: {} + apiKey: <> + # Limits per request + maxInputTokens: 100 + maxOutputTokens: 200 + models: + - claude-sonnet-4-0 + - claude-opus-4-0 + - claude-haiku-3-5 + simpleStore: + # User-facing API keys for the LLM Gateway + users: + - name: alice + apiKey: abc123 + tokens: 250 # Token limit for alice + - name: bob + apiKey: qwertz + tokens: 10000 + # Time in seconds after which the token limit is reset + limitResetPeriod: 60 + - request: + - log: {} + target: + url: https://api.anthropic.com diff --git a/distribution/tutorials/ai/llm-gateway/claude/max-output.json b/distribution/tutorials/ai/llm-gateway/claude/max-output.json index 0b1e1b3b21..b3746f34c6 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/max-output.json +++ b/distribution/tutorials/ai/llm-gateway/claude/max-output.json @@ -4,7 +4,7 @@ "messages": [ { "role": "user", - "content": "What is your name?" + "content": "Explain in detail who you are?" } ] } \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/claude/membrane.cmd b/distribution/tutorials/ai/llm-gateway/claude/membrane.cmd new file mode 100644 index 0000000000..8d2d64e9cf --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/membrane.cmd @@ -0,0 +1,24 @@ +@echo off +setlocal EnableExtensions + +set "SCRIPT_DIR=%~dp0" +if "%SCRIPT_DIR:~-1%"=="\" set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%" + +set "dir=%SCRIPT_DIR%" + +:search_up +if exist "%dir%\LICENSE.txt" if exist "%dir%\scripts\run-membrane.cmd" goto found +for %%A in ("%dir%\..") do set "next=%%~fA" +if /I "%next%"=="%dir%" goto notfound +set "dir=%next%" +goto search_up + +:found +set "MEMBRANE_HOME=%dir%" +set "MEMBRANE_CALLER_DIR=%SCRIPT_DIR%" +call "%MEMBRANE_HOME%\scripts\run-membrane.cmd" %* +exit /b %ERRORLEVEL% + +:notfound +>&2 echo Could not locate Membrane root. Ensure directory structure is correct. +exit /b 1 diff --git a/distribution/tutorials/ai/llm-gateway/claude/membrane.sh b/distribution/tutorials/ai/llm-gateway/claude/membrane.sh new file mode 100755 index 0000000000..195dae51ec --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/membrane.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# Default: ./proxies.xml (next to this script); fallback -> $MEMBRANE_HOME/conf/proxies.xml +# JAVA_OPTS: relative -D paths are auto-resolved against $MEMBRANE_HOME (absolute/URI unchanged). +# Examples: +# export JAVA_OPTS='-Dlog4j.configurationFile=examples/logging/access/log4j2_access.xml' +# export JAVA_OPTS='-Dlog4j.configurationFile=/abs/path/log4j2.xml' + +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd -P) + +dir="$SCRIPT_DIR" +while [ "$dir" != "/" ]; do + if [ -f "$dir/LICENSE.txt" ] && [ -f "$dir/scripts/run-membrane.sh" ]; then + export MEMBRANE_HOME="$dir" + export MEMBRANE_CALLER_DIR="$SCRIPT_DIR" + exec sh "$dir/scripts/run-membrane.sh" "$@" + fi + dir=$(dirname "$dir") +done + +echo "Could not locate Membrane root. Ensure directory structure is correct." >&2 +exit 1 \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/claude/wrong-model.json b/distribution/tutorials/ai/llm-gateway/claude/wrong-model.json new file mode 100644 index 0000000000..d149716e51 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/claude/wrong-model.json @@ -0,0 +1,10 @@ +{ + "model": "gpt-5", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Who are you?" + } + ] +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml new file mode 100644 index 0000000000..ce7ab99ef6 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml @@ -0,0 +1,27 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json +# +# Tutorial: Basic LLM Gateway (Google Gemini) +# +# Replace <> with your OpenAI API key. +# +# 1. Hello World +# curl -v -H "Content-Type: application/json" -H "x-goog-api-key: <>" -d @simple.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent +# Check the response and the Membrane logs. +# +# 2. Exceed the input token limit +# curl -v -H "Content-Type: application/json" -H "x-goog-api-key: <>" -d @max-input.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent +# Returns an error because the request exceeds maxInputTokens. +# +# 3. Exceed the output token limit +# curl -v -H "Content-Type: application/json" -H "x-goog-api-key: <>" -d @max-output.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent +# Check the Membrane log for limiting max tokens to 200 + +api: + port: 2000 + flow: + - llmGateway: + google: {} + maxInputTokens: 100 + maxOutputTokens: 200 + target: + url: https://generativelanguage.googleapis.com \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml new file mode 100644 index 0000000000..89db5cd5b7 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml @@ -0,0 +1,56 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json +# +# Tutorial: Sharing LLM API Keys (Google Gemini) +# +# Replace <> with your Gemini API key. +# +# Requests: +# +# 1. Hello AI +# curl -v -H "Content-Type: application/json" -H "x-goog-api-key: abc123" -d @simple.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent +# Check: Successful response +# +# 2. Token Limit Exceeded +# Repeat the previous request until you receive: 429 Token Limit Exceeded +# User alice is blocked after the limit is exceeded. Bob should still be able to send requests. +# +# 3. Wrong Model +# curl -v -H "Content-Type: application/json" -H "x-goog-api-key: abc123" -d @simple.json http://localhost:2000/v1beta/models/gpt-5:generateContent +# Check: Error response +# +# 4. Max. Input Tokens Exceeded +# curl -v -H "Content-Type: application/json" -H "x-goog-api-key: abc123" -d @max-input.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent +# Check: Error response +# +# 5. Requested Max. Output Tokens Exceeded +# curl -v -H "Content-Type: application/json" -H "x-goog-api-key: abc123" -d @max-output.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent +# Check Membrane log: totalTokens should not exceed 200 even it was requested in max-output.json + +api: + port: 2000 + flow: + - llmGateway: + google: {} + apiKey: <> + # Limits per request + maxInputTokens: 100 + maxOutputTokens: 200 + models: + - gemini-2.5-pro + - gemini-2.5-flash + - gemini-2.5-flash-lite + - gemini-2.0-flash + - gemini-2.0-flash-lite + simpleStore: + # User-facing API keys for the LLM Gateway + users: + - name: alice + apiKey: abc123 + tokens: 500 # Token limit for alice + - name: bob + apiKey: qwertz + tokens: 10000 + # Time in seconds after which the token limit is reset + limitResetPeriod: 60 + target: + url: https://generativelanguage.googleapis.com diff --git a/distribution/tutorials/ai/llm-gateway/google/max-input.json b/distribution/tutorials/ai/llm-gateway/google/max-input.json new file mode 100644 index 0000000000..017608297f --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/google/max-input.json @@ -0,0 +1,11 @@ +{ + "contents": [ + { + "parts": [ + { + "text": "Who are you, where do you get your information from, how do you answer questions, why were you created, what kinds of problems can you solve, where do you go when you search for information, how do you decide what is important, what do you know about programming, science, history, languages, and technology, how do you explain difficult concepts to people, why do people use AI assistants, what happens when you do not know an answer, and why should someone trust the answers you provide?" + } + ] + } + ] +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/google/max-output.json b/distribution/tutorials/ai/llm-gateway/google/max-output.json new file mode 100644 index 0000000000..615c6db3a0 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/google/max-output.json @@ -0,0 +1,14 @@ +{ + "contents": [ + { + "parts": [ + { + "text": "Explain in detail who you are?" + } + ] + } + ], + "generationConfig": { + "maxOutputTokens": 500 + } +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/google/membrane.cmd b/distribution/tutorials/ai/llm-gateway/google/membrane.cmd new file mode 100644 index 0000000000..8d2d64e9cf --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/google/membrane.cmd @@ -0,0 +1,24 @@ +@echo off +setlocal EnableExtensions + +set "SCRIPT_DIR=%~dp0" +if "%SCRIPT_DIR:~-1%"=="\" set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%" + +set "dir=%SCRIPT_DIR%" + +:search_up +if exist "%dir%\LICENSE.txt" if exist "%dir%\scripts\run-membrane.cmd" goto found +for %%A in ("%dir%\..") do set "next=%%~fA" +if /I "%next%"=="%dir%" goto notfound +set "dir=%next%" +goto search_up + +:found +set "MEMBRANE_HOME=%dir%" +set "MEMBRANE_CALLER_DIR=%SCRIPT_DIR%" +call "%MEMBRANE_HOME%\scripts\run-membrane.cmd" %* +exit /b %ERRORLEVEL% + +:notfound +>&2 echo Could not locate Membrane root. Ensure directory structure is correct. +exit /b 1 diff --git a/distribution/tutorials/ai/llm-gateway/google/membrane.sh b/distribution/tutorials/ai/llm-gateway/google/membrane.sh new file mode 100755 index 0000000000..195dae51ec --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/google/membrane.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# Default: ./proxies.xml (next to this script); fallback -> $MEMBRANE_HOME/conf/proxies.xml +# JAVA_OPTS: relative -D paths are auto-resolved against $MEMBRANE_HOME (absolute/URI unchanged). +# Examples: +# export JAVA_OPTS='-Dlog4j.configurationFile=examples/logging/access/log4j2_access.xml' +# export JAVA_OPTS='-Dlog4j.configurationFile=/abs/path/log4j2.xml' + +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd -P) + +dir="$SCRIPT_DIR" +while [ "$dir" != "/" ]; do + if [ -f "$dir/LICENSE.txt" ] && [ -f "$dir/scripts/run-membrane.sh" ]; then + export MEMBRANE_HOME="$dir" + export MEMBRANE_CALLER_DIR="$SCRIPT_DIR" + exec sh "$dir/scripts/run-membrane.sh" "$@" + fi + dir=$(dirname "$dir") +done + +echo "Could not locate Membrane root. Ensure directory structure is correct." >&2 +exit 1 \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/google/simple.json b/distribution/tutorials/ai/llm-gateway/google/simple.json new file mode 100644 index 0000000000..3bf6c67b2e --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/google/simple.json @@ -0,0 +1,11 @@ +{ + "contents": [ + { + "parts": [ + { + "text": "Who are you?" + } + ] + } + ] +} \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml index 16a0327022..b0231ca3d8 100644 --- a/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml @@ -2,12 +2,12 @@ # # Tutorial: Sharing LLM API Keys (OpenAI) # -# Replace <> with your OpenAI API key. +# Replace <> with your OpenAI API key. # # Requests: # # 1. Hello AI -# curl -H "Content-Type: application/json" -H "Authorization: Bearer abc123" -d @simple.json http://localhost:2000/v1/responses +# curl -v -H "Content-Type: application/json" -H "Authorization: Bearer abc123" -d @simple.json http://localhost:2000/v1/responses # Check: Successful response # # 2. Token Limit Exceeded @@ -30,8 +30,7 @@ api: port: 2000 flow: - llmGateway: - # Replace <> with your OpenAI API key - apiKey: <> + apiKey: <> # Limits per request maxInputTokens: 100 maxOutputTokens: 200 diff --git a/distribution/tutorials/ai/llm-gateway/openai/max-output.json b/distribution/tutorials/ai/llm-gateway/openai/max-output.json index 65a63165f2..cc7e04017f 100644 --- a/distribution/tutorials/ai/llm-gateway/openai/max-output.json +++ b/distribution/tutorials/ai/llm-gateway/openai/max-output.json @@ -1,5 +1,5 @@ { "model": "gpt-5-nano", - "input": "What is your name?", + "input": "Explain in detail who you are?", "max_output_tokens": 500 } \ No newline at end of file From 38acaa360cee0e2578879c982364886c7dc84271 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 21 May 2026 14:32:25 +0200 Subject: [PATCH 29/43] feat: add AI LLM Gateway tests for Claude, OpenAI, and Google Gemini tutorials - Introduced `AbstractAiTutorialTest` base class and provider-specific extensions for easier test creation. - Added integration tests for basic gateway setups and API key sharing for Claude, OpenAI, and Google Gemini. - Simulated upstream mock APIs to enable testing token limits, key handling, and input/output transformations. --- .../ai/provider/google/GoogleLLMRequest.java | 5 + .../interceptor/mcp/ExchangeToolSupport.java | 2 +- .../ai/llmgateway/AbstractAiTutorialTest.java | 151 ++++++++++++ .../BasicClaudeLLMGatewayTutorialTest.java | 114 +++++++++ .../claude/SharingApiKeysTutorialTest.java | 223 ++++++++++++++++++ .../google/AbstractGoogleTutorialTest.java | 58 +++++ .../BasicGoogleLLMGatewayTutorialTest.java | 109 +++++++++ .../SharingApiKeysGoogleTutorialTest.java | 219 +++++++++++++++++ .../openai/AbstractOpenAiTutorialTest.java | 61 +++++ .../BasicOpenAiLLMGatewayTutorialTest.java | 105 +++++++++ .../SharingApiKeysOpenAiTutorialTest.java | 208 ++++++++++++++++ 11 files changed, 1254 insertions(+), 1 deletion(-) create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/AbstractGoogleTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/BasicGoogleLLMGatewayTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/AbstractOpenAiTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/BasicOpenAiLLMGatewayTutorialTest.java create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java index da0b174465..07da55b089 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java @@ -34,7 +34,12 @@ public String getModel() { var modelPart = uri.substring(modelsIndex + "/models/".length()); + // Support both ':' and URL-encoded '%3A' / '%3a' as separator before the action suffix + // (e.g. ':generateContent' or '%3AgenerateContent'). int colonIndex = modelPart.indexOf(':'); + if (colonIndex < 0) { + colonIndex = modelPart.toLowerCase().indexOf("%3a"); + } if (colonIndex >= 0) { return modelPart.substring(0, colonIndex); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java index 6bfb14c5c8..ed9f1608d8 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/ExchangeToolSupport.java @@ -17,7 +17,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.predic8.membrane.core.exchange.AbstractExchange; -import com.predic8.membrane.core.interceptor.mcp.MCPUtil.*; import com.predic8.membrane.core.mcp.MCPToolsCall; import com.predic8.membrane.core.mcp.MCPToolsCallResponse; import org.jetbrains.annotations.Nullable; @@ -27,6 +26,7 @@ import java.util.*; import static com.predic8.membrane.core.interceptor.mcp.ExchangeUtils.matchesExchangeFilter; +import static com.predic8.membrane.core.interceptor.mcp.MCPUtil.*; import static com.predic8.membrane.core.interceptor.mcp.McpSchemaBuilder.integer; import static com.predic8.membrane.core.interceptor.mcp.McpSchemaBuilder.string; import static java.lang.Integer.MAX_VALUE; diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java new file mode 100644 index 0000000000..fb5a37a07c --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java @@ -0,0 +1,151 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.AbstractInterceptor; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.flow.ReturnInterceptor; +import com.predic8.membrane.core.interceptor.templating.StaticInterceptor; +import com.predic8.membrane.core.proxies.ServiceProxy; +import com.predic8.membrane.core.proxies.ServiceProxyKey; +import com.predic8.membrane.core.router.DefaultRouter; +import com.predic8.membrane.examples.util.DistributionExtractingTestcase; +import com.predic8.membrane.examples.util.Process2; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.util.function.Consumer; + +/** + * Base class for AI tutorial tests. Starts a local Membrane mock of the upstream LLM API + * so tests run without a real API key and without network access to the LLM provider. + * + *

The tutorial YAML's {@code target.url} is rewritten to point at the mock server + * before Membrane starts. Subclasses override {@link #getTutorialDir()} and + * {@link #getTutorialYaml()} to select the tutorial under test. + * + *

JUnit 5 lifecycle ordering guarantees that {@code DistributionExtractingTestcase.init()} + * (superclass {@code @BeforeEach}) runs first and sets {@code baseDir}, allowing + * {@link #startGateway()} to use {@code replaceInFile2()} safely. + */ +public abstract class AbstractAiTutorialTest extends DistributionExtractingTestcase { + + protected static final int MOCK_LLM_PORT = 3100; + + protected Process2 process; + protected volatile String lastRequestBody; + protected volatile String lastRequestApiKey; + + private DefaultRouter mockRouter; + + protected abstract String getTutorialDir(); + protected abstract String getTutorialYaml(); + + @Override + protected String getExampleDirName() { + return "../tutorials/%s".formatted(getTutorialDir()); + } + + @Override + protected String getParameters() { + return "-c %s".formatted(getTutorialYaml()); + } + + /** + * Runs after {@code DistributionExtractingTestcase.init()} sets {@code baseDir}. + * Starts the mock, patches the YAML, then starts Membrane. + */ + @BeforeEach + void startGateway() throws Exception { + startMockLlmApi(); + replaceInFile2(getTutorialYaml(), getUpstreamApiUrl(), mockApiUrl()); + process = startServiceProxyScript(); + } + + @AfterEach + void stopGateway() { + if (process != null) + process.killScript(); + if (mockRouter != null) + mockRouter.stop(); + } + + /** + * The upstream API URL used in the tutorial YAML (to be replaced by the mock URL). + */ + protected String getUpstreamApiUrl() { + return "https://api.anthropic.com"; + } + + protected String mockApiUrl() { + return "http://localhost:" + MOCK_LLM_PORT; + } + + /** + * The HTTP header name from which the upstream API key is read when capturing + * requests in the mock. Defaults to {@code "x-api-key"} (Claude). Override to + * {@code "authorization"} for OpenAI or {@code "x-goog-api-key"} for Google. + */ + protected String getApiKeyHeader() { + return "x-api-key"; + } + + private void startMockLlmApi() throws Exception { + var si = new StaticInterceptor(); + si.setSrc(mockResponse()); + si.setContentType("application/json"); + + var sp = new ServiceProxy(new ServiceProxyKey(MOCK_LLM_PORT), null, 0); + sp.getFlow().add(new BodyCaptureInterceptor( + body -> lastRequestBody = body, + apiKey -> lastRequestApiKey = apiKey, + getApiKeyHeader())); + sp.getFlow().add(si); + sp.getFlow().add(new ReturnInterceptor()); + + mockRouter = new DefaultRouter(); + mockRouter.add(sp); + mockRouter.start(); + } + + private static class BodyCaptureInterceptor extends AbstractInterceptor { + + private final Consumer bodySink; + private final Consumer apiKeySink; + private final String apiKeyHeader; + + BodyCaptureInterceptor(Consumer bodySink, Consumer apiKeySink, String apiKeyHeader) { + this.bodySink = bodySink; + this.apiKeySink = apiKeySink; + this.apiKeyHeader = apiKeyHeader; + } + + @Override + public Outcome handleRequest(Exchange exc) { + bodySink.accept(exc.getRequest().getBodyAsStringDecoded()); + apiKeySink.accept(exc.getRequest().getHeader().getFirstValue(apiKeyHeader)); + return Outcome.CONTINUE; + } + } + + protected String mockResponse() { + return """ + {"id":"msg_mock","type":"message","role":"assistant",\ + "content":[{"type":"text","text":"I am a mock."}],\ + "model":"claude-sonnet-4-0","stop_reason":"end_turn",\ + "usage":{"input_tokens":10,"output_tokens":5}}"""; + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java new file mode 100644 index 0000000000..3cde3fa976 --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java @@ -0,0 +1,114 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.claude; + +import com.predic8.membrane.tutorials.ai.llmgateway.AbstractAiTutorialTest; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static io.restassured.RestAssured.given; +import static io.restassured.path.json.JsonPath.from; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +/** + * Integration test for {@code distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml}. + * + *

The tutorial configures a Claude LLM gateway with: + *

    + *
  • {@code maxInputTokens: 100} — requests whose estimated input exceeds 100 tokens are rejected
  • + *
  • {@code maxOutputTokens: 200} — {@code max_tokens} in the forwarded request is capped to 200
  • + *
+ * + *

The upstream Anthropic API is replaced by a local mock server so no real API key is needed. + */ +public class BasicClaudeLLMGatewayTutorialTest extends AbstractAiTutorialTest { + + @Override + protected String getTutorialDir() { + return "ai/llm-gateway/claude"; + } + + @Override + protected String getTutorialYaml() { + return "10-Basic-LLM-Gateway.yaml"; + } + + /** + * A request within the token limits is forwarded to the upstream and its response is returned. + */ + @Test + void simpleRequestIsForwarded() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", "test-key") + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(200) + .body("type", equalTo("message")) + .body("content[0].type", equalTo("text")); + // @formatter:on + } + + /** + * A request whose message content exceeds {@code maxInputTokens} (100) is rejected by the + * gateway before reaching the upstream. The response uses the Claude error format. + */ + @Test + void inputTokenLimitExceededIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", "test-key") + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("max-input.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(400) + .body("type", equalTo("error")) + .body("error.type", equalTo("invalid_request_error")) + .body("error.message", containsString("tokens")); + // @formatter:on + } + + /** + * When the request asks for more output tokens than {@code maxOutputTokens} (200) allows, + * the gateway rewrites {@code max_tokens} to 200 before forwarding to the upstream. + * The mock captures the forwarded body so we can verify the value was actually capped. + */ + @Test + void outputTokensAreCappedBeforeForwarding() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", "test-key") + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(200); + // @formatter:on + + assertThat(from(lastRequestBody).getInt("max_tokens"), equalTo(200)); + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java new file mode 100644 index 0000000000..33174d6ac8 --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java @@ -0,0 +1,223 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.claude; + +import com.predic8.membrane.tutorials.ai.llmgateway.AbstractAiTutorialTest; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; + +/** + * Integration tests for + * {@code distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml}. + * + *

The tutorial demonstrates sharing a single upstream API key between multiple users, + * each identified by their own gateway key and subject to individual token budgets: + *

    + *
  • alice — key {@code abc123}, budget 250 tokens
  • + *
  • bob — key {@code qwertz}, budget 10 000 tokens
  • + *
+ * Additional gateway limits: {@code maxInputTokens=100}, {@code maxOutputTokens=200}, + * allowed models: {@code claude-sonnet-4-0}, {@code claude-opus-4-0}, {@code claude-haiku-3-5}. + */ +public class SharingApiKeysTutorialTest extends AbstractAiTutorialTest { + + private static final String ALICE = "abc123"; + private static final String BOB = "qwertz"; + + @Override + protected String getTutorialDir() { + return "ai/llm-gateway/claude"; + } + + @Override + protected String getTutorialYaml() { + return "20-Sharing-API-Keys.yaml"; + } + + @Test + void aliceCanSendRequest() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", ALICE) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(200) + .body("type", equalTo("message")); + // @formatter:on + } + + @Test + void bobCanSendRequest() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", BOB) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(200) + .body("type", equalTo("message")); + // @formatter:on + } + + @Test + void unknownApiKeyIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", "invalid-key") + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(401) + .body("type", equalTo("error")) + .body("error.type", equalTo("authentication_error")); + // @formatter:on + } + + /** + * The gateway is configured with its own upstream {@code apiKey}. When a user request + * arrives carrying the user-facing key (e.g. alice's {@code abc123}), the gateway must + * replace it with the configured upstream key before forwarding to the LLM provider. + */ + @Test + void userApiKeyIsReplacedWithGatewayApiKey() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", ALICE) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(200); + // @formatter:on + + assertThat(lastRequestApiKey, not(equalTo(ALICE))); + assertThat(lastRequestApiKey, equalTo("<>")); + } + + @Test + void wrongModelIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", ALICE) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("wrong-model.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(400) + .body("type", equalTo("error")) + .body("error.type", equalTo("invalid_request_error")) + .body("error.message", containsString("gpt-5")) + .body("error.message", containsString("not allowed")); + // @formatter:on + } + + @Test + void inputTokenLimitExceededIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", ALICE) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("max-input.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(400) + .body("type", equalTo("error")) + .body("error.type", equalTo("invalid_request_error")) + .body("error.message", containsString("prompt is too long")) + .body("error.message", containsString("100 maximum")); + // @formatter:on + } + + /** + * Alice has a budget of 250 tokens. Each request with {@code max-output.json} projects + * 7 (input estimate) + 200 (capped max_tokens) = 207 tokens. The mock returns 15 tokens + * of actual usage per call, so the running total grows by 15 after each response. + * + *

Budget accounting per request: + *

+     *   1st: 250 - 0   - 207 =  43  → forwarded; used becomes 15
+     *   2nd: 250 - 15  - 207 =  28  → forwarded; used becomes 30
+     *   3rd: 250 - 30  - 207 =  13  → forwarded; used becomes 45
+     *   4th: 250 - 45  - 207 =  -2  → rejected with 429
+     * 
+ * + * Bob's separate budget of 10 000 tokens is unaffected, so he can still send requests + * after alice is blocked. + */ + @Test + void alicesTokenBudgetIsExhaustedWhileBobIsUnaffected() throws IOException { + for (int i = 0; i < 3; i++) { + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", ALICE) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(200); + // @formatter:on + } + + // Alice's budget is now exhausted + // @formatter:off + given() + .contentType("application/json") + .header("x-api-key", ALICE) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(429) + .body("type", equalTo("error")) + .body("error.type", equalTo("rate_limit_error")); + + // Bob's budget is independent — he can still send requests + given() + .contentType("application/json") + .header("x-api-key", BOB) + .header("anthropic-version", "2023-06-01") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/messages") + .then() + .statusCode(200) + .body("type", equalTo("message")); + // @formatter:on + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/AbstractGoogleTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/AbstractGoogleTutorialTest.java new file mode 100644 index 0000000000..4e39f7ae6c --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/AbstractGoogleTutorialTest.java @@ -0,0 +1,58 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.google; + +import com.predic8.membrane.tutorials.ai.llmgateway.AbstractAiTutorialTest; + +/** + * Base class for Google Gemini LLM-Gateway tutorial tests. + * + *

Overrides the upstream URL and the API-key header so the mock captures + * the {@code x-goog-api-key} header that Google uses. The mock response is + * formatted as a Gemini {@code generateContent} reply and reports 100 total + * tokens (50 prompt + 50 candidates) per call. + */ +public abstract class AbstractGoogleTutorialTest extends AbstractAiTutorialTest { + + /** URL prefix used in both Google tutorial YAML files. */ + @Override + protected String getUpstreamApiUrl() { + return "https://generativelanguage.googleapis.com"; + } + + @Override + protected String getTutorialDir() { + return "ai/llm-gateway/google"; + } + + /** Google authenticates via the {@code x-goog-api-key} header. */ + @Override + protected String getApiKeyHeader() { + return "x-goog-api-key"; + } + + /** + * Minimal Gemini {@code generateContent} reply with 50 prompt + 50 candidates = 100 total + * tokens. The higher per-request cost keeps the token-budget exhaustion test to three + * successful requests before alice's 500-token allowance runs out. + */ + @Override + protected String mockResponse() { + return """ + {"candidates":[{"content":{"parts":[{"text":"I am a mock."}],"role":"model"},\ + "finishReason":"STOP"}],\ + "usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":50,"totalTokenCount":100}}"""; + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/BasicGoogleLLMGatewayTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/BasicGoogleLLMGatewayTutorialTest.java new file mode 100644 index 0000000000..16f52d470b --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/BasicGoogleLLMGatewayTutorialTest.java @@ -0,0 +1,109 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.google; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static io.restassured.RestAssured.given; +import static io.restassured.path.json.JsonPath.from; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +/** + * Integration test for + * {@code distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml}. + * + *

The tutorial configures a Google Gemini LLM gateway with: + *

    + *
  • {@code maxInputTokens: 100} — requests whose estimated input exceeds 100 tokens are rejected
  • + *
  • {@code maxOutputTokens: 200} — {@code generationConfig.maxOutputTokens} in the forwarded + * request is capped to 200
  • + *
+ * + *

The upstream Google Gemini API is replaced by a local mock server so no real API key is needed. + */ +public class BasicGoogleLLMGatewayTutorialTest extends AbstractGoogleTutorialTest { + + private static final String GEMINI_ENDPOINT = + LOCALHOST_2000 + "/v1beta/models/gemini-2.5-flash:generateContent"; + + @Override + protected String getTutorialYaml() { + return "10-Basic-LLM-Gateway.yaml"; + } + + /** + * A request within the token limits is forwarded to the upstream and its response is returned. + */ + @Test + void simpleRequestIsForwarded() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", "test-key") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(GEMINI_ENDPOINT) + .then() + .statusCode(200) + .body("candidates[0].content.parts[0].text", equalTo("I am a mock.")); + // @formatter:on + } + + /** + * A request whose message content exceeds {@code maxInputTokens} (100) is rejected by the + * gateway before reaching the upstream. The response uses the Google error format. + */ + @Test + void inputTokenLimitExceededIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", "test-key") + .body(readFileFromBaseDir("max-input.json")) + .when() + .post(GEMINI_ENDPOINT) + .then() + .statusCode(400) + .body("error.status", equalTo("INVALID_ARGUMENT")) + .body("error.message", containsString("exceeds the maximum allowed")) + .body("error.message", containsString("100")); + // @formatter:on + } + + /** + * When the request asks for more output tokens than {@code maxOutputTokens} (200) allows, + * the gateway rewrites {@code generationConfig.maxOutputTokens} to 200 before forwarding. + * The mock captures the forwarded body so we can verify the value was actually capped. + */ + @Test + void outputTokensAreCappedBeforeForwarding() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", "test-key") + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(GEMINI_ENDPOINT) + .then() + .statusCode(200); + // @formatter:on + + assertThat(from(lastRequestBody).getInt("generationConfig.maxOutputTokens"), equalTo(200)); + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java new file mode 100644 index 0000000000..567da88b9e --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java @@ -0,0 +1,219 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.google; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; + +/** + * Integration tests for + * {@code distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml}. + * + *

The tutorial demonstrates sharing a single upstream API key between multiple users, + * each identified by their own gateway key and subject to individual token budgets: + *

    + *
  • alice — key {@code abc123}, budget 500 tokens
  • + *
  • bob — key {@code qwertz}, budget 10 000 tokens
  • + *
+ * Additional gateway limits: {@code maxInputTokens=100}, {@code maxOutputTokens=200}, + * allowed models: {@code gemini-2.5-pro}, {@code gemini-2.5-flash}, {@code gemini-2.5-flash-lite}, + * {@code gemini-2.0-flash}, {@code gemini-2.0-flash-lite}. + * + *

For Google Gemini the model is part of the URL path + * ({@code /v1beta/models/:generateContent}), not the request body. + */ +public class SharingApiKeysGoogleTutorialTest extends AbstractGoogleTutorialTest { + + private static final String ALICE = "abc123"; + private static final String BOB = "qwertz"; + + private static final String GEMINI_FLASH_ENDPOINT = + LOCALHOST_2000 + "/v1beta/models/gemini-2.5-flash:generateContent"; + + @Override + protected String getTutorialYaml() { + return "20-Sharing-API-Keys.yaml"; + } + + @Test + void aliceCanSendRequest() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", ALICE) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .statusCode(200) + .body("candidates[0].content.parts[0].text", equalTo("I am a mock.")); + // @formatter:on + } + + @Test + void bobCanSendRequest() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", BOB) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .statusCode(200) + .body("candidates[0].content.parts[0].text", equalTo("I am a mock.")); + // @formatter:on + } + + @Test + void unknownApiKeyIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", "invalid-key") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .statusCode(401) + .body("error.status", equalTo("UNAUTHENTICATED")) + .body("error.message", containsString("Invalid API key")); + // @formatter:on + } + + /** + * The gateway is configured with its own upstream {@code apiKey}. When a user request + * arrives carrying the user-facing key (e.g. alice's {@code abc123}), the gateway must + * replace it with the configured upstream key before forwarding to the LLM provider. + * For Google Gemini, the key is carried in the {@code x-goog-api-key} header. + */ + @Test + void userApiKeyIsReplacedWithGatewayApiKey() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", ALICE) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .log().ifValidationFails() + .statusCode(200); + // @formatter:on + + assertThat(lastRequestApiKey, not(equalTo(ALICE))); + assertThat(lastRequestApiKey, equalTo("<>")); + } + + /** + * For Google Gemini the model is extracted from the URL path. Sending a request to + * {@code /v1beta/models/gpt-5:generateContent} uses model {@code gpt-5}, which is not + * in the allowed list, so the gateway rejects it with 400. + */ + @Test + void wrongModelIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", ALICE) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1beta/models/gpt-5:generateContent") + .then() + .statusCode(400) + .body("error.status", equalTo("INVALID_ARGUMENT")) + .body("error.message", containsString("gpt-5")) + .body("error.message", containsString("not allowed")); + // @formatter:on + } + + @Test + void inputTokenLimitExceededIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", ALICE) + .body(readFileFromBaseDir("max-input.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .statusCode(400) + .body("error.status", equalTo("INVALID_ARGUMENT")) + .body("error.message", containsString("exceeds the maximum allowed")) + .body("error.message", containsString("100")); + // @formatter:on + } + + /** + * Alice has a budget of 500 tokens. Each request with {@code max-output.json} projects + * 9 (input estimate) + 200 (capped maxOutputTokens) = 209 tokens. The mock returns + * 100 tokens of actual usage per call, so the running total grows by 100 after each response. + * + *

Budget accounting per request: + *

+     *   1st: 500 - 0   - 209 = 291  → forwarded; used becomes 100
+     *   2nd: 500 - 100 - 209 = 191  → forwarded; used becomes 200
+     *   3rd: 500 - 200 - 209 =  91  → forwarded; used becomes 300
+     *   4th: 500 - 300 - 209 =  -9  → rejected with 429
+     * 
+ * + * Bob's separate budget of 10 000 tokens is unaffected, so he can still send requests + * after alice is blocked. + */ + @Test + void alicesTokenBudgetIsExhaustedWhileBobIsUnaffected() throws IOException { + for (int i = 0; i < 3; i++) { + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", ALICE) + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .statusCode(200); + // @formatter:on + } + + // Alice's budget is now exhausted + // @formatter:off + given() + .contentType("application/json") + .header("x-goog-api-key", ALICE) + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .statusCode(429) + .body("error.status", equalTo("RESOURCE_EXHAUSTED")); + + // Bob's budget is independent — he can still send requests + given() + .contentType("application/json") + .header("x-goog-api-key", BOB) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(GEMINI_FLASH_ENDPOINT) + .then() + .statusCode(200) + .body("candidates[0].content.parts[0].text", equalTo("I am a mock.")); + // @formatter:on + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/AbstractOpenAiTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/AbstractOpenAiTutorialTest.java new file mode 100644 index 0000000000..54136f4c2f --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/AbstractOpenAiTutorialTest.java @@ -0,0 +1,61 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.openai; + +import com.predic8.membrane.tutorials.ai.llmgateway.AbstractAiTutorialTest; + +/** + * Base class for OpenAI LLM-Gateway tutorial tests. + * + *

Overrides the upstream URL and the API-key header so the mock captures + * the {@code Authorization} header that OpenAI uses instead of {@code x-api-key}. + * The mock response is formatted as an OpenAI Responses-API reply and reports + * 100 total tokens (50 input + 50 output) per call. + */ +public abstract class AbstractOpenAiTutorialTest extends AbstractAiTutorialTest { + + @Override + protected String getTutorialDir() { + return "ai/llm-gateway/openai"; + } + + @Override + protected String getUpstreamApiUrl() { + return "https://api.openai.com"; + } + + /** + * OpenAI authenticates via {@code Authorization: Bearer }. + * The full header value (including the "Bearer " prefix) is captured. + */ + @Override + protected String getApiKeyHeader() { + return "authorization"; + } + + /** + * Minimal OpenAI Responses-API reply with 50 input + 50 output = 100 total tokens. + * The higher per-request cost (vs. the default Claude mock) keeps the token-budget + * exhaustion test to three successful requests before alice's 500-token allowance runs out. + */ + @Override + protected String mockResponse() { + return """ + {"id":"resp_mock","object":"response","model":"gpt-5-nano",\ + "output":[{"type":"message","role":"assistant",\ + "content":[{"type":"output_text","text":"I am a mock."}]}],\ + "usage":{"input_tokens":50,"output_tokens":50,"total_tokens":100}}"""; + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/BasicOpenAiLLMGatewayTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/BasicOpenAiLLMGatewayTutorialTest.java new file mode 100644 index 0000000000..6dd96ee098 --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/BasicOpenAiLLMGatewayTutorialTest.java @@ -0,0 +1,105 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.openai; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static io.restassured.RestAssured.given; +import static io.restassured.path.json.JsonPath.from; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +/** + * Integration test for + * {@code distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml}. + * + *

The tutorial configures an OpenAI LLM gateway with: + *

    + *
  • {@code maxInputTokens: 100} — requests whose estimated input exceeds 100 tokens are rejected
  • + *
  • {@code maxOutputTokens: 200} — {@code max_output_tokens} in the forwarded request is capped to 200
  • + *
+ * + *

The upstream OpenAI API is replaced by a local mock server so no real API key is needed. + */ +public class BasicOpenAiLLMGatewayTutorialTest extends AbstractOpenAiTutorialTest { + + @Override + protected String getTutorialYaml() { + return "10-Basic-LLM-Gateway.yaml"; + } + + /** + * A request within the token limits is forwarded to the upstream and its response is returned. + */ + @Test + void simpleRequestIsForwarded() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer test-key") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(200) + .body("object", equalTo("response")); + // @formatter:on + } + + /** + * A request whose message content exceeds {@code maxInputTokens} (100) is rejected by the + * gateway before reaching the upstream. The response uses the OpenAI error format. + */ + @Test + void inputTokenLimitExceededIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer test-key") + .body(readFileFromBaseDir("max-input.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(400) + .body("error.type", equalTo("invalid_request_error")) + .body("error.code", equalTo("context_length_exceeded")) + .body("error.message", containsString("100")); + // @formatter:on + } + + /** + * When the request asks for more output tokens than {@code maxOutputTokens} (200) allows, + * the gateway rewrites {@code max_output_tokens} to 200 before forwarding to the upstream. + * The mock captures the forwarded body so we can verify the value was actually capped. + */ + @Test + void outputTokensAreCappedBeforeForwarding() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer test-key") + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(200); + // @formatter:on + + assertThat(from(lastRequestBody).getInt("max_output_tokens"), equalTo(200)); + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java new file mode 100644 index 0000000000..7bd410fc24 --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java @@ -0,0 +1,208 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.openai; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; + +/** + * Integration tests for + * {@code distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml}. + * + *

The tutorial demonstrates sharing a single upstream API key between multiple users, + * each identified by their own gateway key and subject to individual token budgets: + *

    + *
  • alice — key {@code abc123}, budget 500 tokens
  • + *
  • bob — key {@code qwertz}, budget 10 000 tokens
  • + *
+ * Additional gateway limits: {@code maxInputTokens=100}, {@code maxOutputTokens=200}, + * allowed models: {@code gpt-5.4}, {@code gpt-5-nano}, {@code gpt-5-mini}. + */ +public class SharingApiKeysOpenAiTutorialTest extends AbstractOpenAiTutorialTest { + + private static final String ALICE = "abc123"; + private static final String BOB = "qwertz"; + + @Override + protected String getTutorialYaml() { + return "20-Sharing-API-Keys.yaml"; + } + + @Test + void aliceCanSendRequest() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer " + ALICE) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(200) + .body("object", equalTo("response")); + // @formatter:on + } + + @Test + void bobCanSendRequest() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer " + BOB) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(200) + .body("object", equalTo("response")); + // @formatter:on + } + + @Test + void unknownApiKeyIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer invalid-key") + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(401) + .body("error.code", equalTo("invalid_authentication")); + // @formatter:on + } + + /** + * The gateway is configured with its own upstream {@code apiKey}. When a user request + * arrives carrying the user-facing key (e.g. alice's {@code abc123}), the gateway must + * replace it with the configured upstream key before forwarding to the LLM provider. + * For OpenAI, the key is carried in the {@code Authorization: Bearer } header. + */ + @Test + void userApiKeyIsReplacedWithGatewayApiKey() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer " + ALICE) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(200); + // @formatter:on + + assertThat(lastRequestApiKey, not(equalTo("Bearer " + ALICE))); + assertThat(lastRequestApiKey, equalTo("Bearer <>")); + } + + @Test + void wrongModelIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer " + ALICE) + .body(readFileFromBaseDir("wrong-model.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(400) + .body("error.type", equalTo("invalid_request_error")) + .body("error.code", equalTo("model_not_allowed")) + .body("error.message", containsString("gpt-4")) + .body("error.message", containsString("not allowed")); + // @formatter:on + } + + @Test + void inputTokenLimitExceededIsRejected() throws IOException { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer " + ALICE) + .body(readFileFromBaseDir("max-input.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(400) + .body("error.type", equalTo("invalid_request_error")) + .body("error.code", equalTo("context_length_exceeded")) + .body("error.message", containsString("maximum context length")) + .body("error.message", containsString("100")); + // @formatter:on + } + + /** + * Alice has a budget of 500 tokens. Each request with {@code max-output.json} projects + * 9 (input estimate) + 200 (capped max_output_tokens) = 209 tokens. The mock returns + * 100 tokens of actual usage per call, so the running total grows by 100 after each response. + * + *

Budget accounting per request: + *

+     *   1st: 500 - 0   - 209 = 291  → forwarded; used becomes 100
+     *   2nd: 500 - 100 - 209 = 191  → forwarded; used becomes 200
+     *   3rd: 500 - 200 - 209 =  91  → forwarded; used becomes 300
+     *   4th: 500 - 300 - 209 =  -9  → rejected with 429
+     * 
+ * + * Bob's separate budget of 10 000 tokens is unaffected, so he can still send requests + * after alice is blocked. + */ + @Test + void alicesTokenBudgetIsExhaustedWhileBobIsUnaffected() throws IOException { + for (int i = 0; i < 3; i++) { + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer " + ALICE) + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(200); + // @formatter:on + } + + // Alice's budget is now exhausted + // @formatter:off + given() + .contentType("application/json") + .header("Authorization", "Bearer " + ALICE) + .body(readFileFromBaseDir("max-output.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(429) + .body("error.type", equalTo("rate_limit_error")) + .body("error.code", equalTo("token_limit_exceeded")); + + // Bob's budget is independent — he can still send requests + given() + .contentType("application/json") + .header("Authorization", "Bearer " + BOB) + .body(readFileFromBaseDir("simple.json")) + .when() + .post(LOCALHOST_2000 + "/v1/responses") + .then() + .statusCode(200) + .body("object", equalTo("response")); + // @formatter:on + } +} From 9d38d849ef3afb0d3aa6b5ba9d21454dd2713c22 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 21 May 2026 15:54:03 +0200 Subject: [PATCH 30/43] feat: improve token handling, configuration validation, and examples for LLM Gateway - Ensure thread-safe access to users in `SimpleAiApiStore` with `List.copyOf`. - Introduce `visibleRemaining` to handle non-negative token values in `GoogleErrorCreator`. - Add configuration validation in `LLMGatewayInterceptor` to enforce API key substitution. - Enhance token limit handling to adjust output tokens dynamically in `LLMGatewayInterceptor`. - Update Google and Claude tutorials with clearer instructions for API key usage and token limits. --- .../interceptor/ai/LLMGatewayInterceptor.java | 28 +++++++++++++++---- .../provider/google/GoogleErrorCreator.java | 4 ++- .../ai/store/SimpleAiApiStore.java | 2 +- .../claude/20-Sharing-API-Keys.yaml | 2 +- .../google/10-Basic-LLM-Gateway.yaml | 2 +- .../google/20-Sharing-API-Keys.yaml | 2 +- 6 files changed, 30 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 41c536c39b..8a63100c08 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -11,6 +11,7 @@ import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; import com.predic8.membrane.core.interceptor.ai.store.AiApiUser; +import com.predic8.membrane.core.util.ConfigurationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,6 +56,11 @@ public void init() { errorCreator = provider.getErrorCreator(); if (store != null) store.init(router); + + // Check if the replacement markers are still there + if (apiKey.contains("<<") && apiKey.contains(">>")) { + throw new ConfigurationException("The configuration contains the replacement marker %s. Substitute it with the API key of the model.".formatted(apiKey)); + } } @Override @@ -69,7 +75,8 @@ public Outcome handleRequest(Exchange exc) { } if (!exc.getRequest().isPOSTRequest()) { - aiReq.setApiKey(apiKey); + if (apiKey != null) + aiReq.setApiKey(apiKey); return CONTINUE; } @@ -90,7 +97,7 @@ public Outcome handleRequest(Exchange exc) { // Check store limits if (store != null) { - var effectiveMaxTokens = Math.min(aiReq.getRequestedMaxOutputTokens(), maxOutputTokens); + var effectiveMaxTokens = computeEffectiveMaxOutputTokens(aiReq.getRequestedMaxOutputTokens(), maxOutputTokens); var remaining = store.checkLimit(user, inputTokens, effectiveMaxTokens); log.debug("User {} has {} remaining tokens left", user, remaining); if (remaining <= 0) { @@ -109,9 +116,14 @@ public Outcome handleRequest(Exchange exc) { var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); - if (maxOutputTokens != 0 && (requestedMaxOutputTokens == -1 || requestedMaxOutputTokens > maxOutputTokens)) { - log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, maxOutputTokens); - aiReq.setMaxOutputTokens(maxOutputTokens); + if (maxOutputTokens > 0) { + if (requestedMaxOutputTokens == -1) { + log.info("No max. output requested. Setting limit to {}.", maxOutputTokens); + aiReq.setMaxOutputTokens(maxOutputTokens); + } else if (requestedMaxOutputTokens > maxOutputTokens) { + log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, maxOutputTokens); + aiReq.setMaxOutputTokens(maxOutputTokens); + } } if (maxInputTokens != 0) { @@ -136,6 +148,12 @@ public Outcome handleRequest(Exchange exc) { return CONTINUE; } + long computeEffectiveMaxOutputTokens(long requestedMaxOutputTokens, long maxOutputTokens) { + if (requestedMaxOutputTokens <= 0) + return maxOutputTokens; + return Math.min(requestedMaxOutputTokens, maxOutputTokens); + } + @Override public Outcome handleResponse(Exchange exc) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java index 281a314594..92a115194a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java @@ -21,6 +21,8 @@ public Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds) { + var visibleRemaining = Math.max(0, tokenRemaining); + return statusCode(429).json( envelope( 429, @@ -29,7 +31,7 @@ public Response tokenLimitExceeded(long tokenRequired, Request requires %d tokens but only %d remain. Retry after %d seconds. """ - .formatted(tokenRequired, tokenRemaining, tokenResetInSeconds) + .formatted(tokenRequired, visibleRemaining, tokenResetInSeconds) .trim(), "RESOURCE_EXHAUSTED" ) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java index 0793a2c75d..f9b8218608 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java @@ -87,7 +87,7 @@ public void setUsers(List users) { public List getUsers() { synchronized (lock) { - return users; + return List.copyOf(users); } } diff --git a/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml index 6cb3f0c0b3..44bc28a8ee 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml @@ -24,7 +24,7 @@ # # 5. Requested Max. Output Tokens Exceeded # curl -v -H "Content-Type: application/json" -H "x-api-key: abc123" -H "anthropic-version: 2023-06-01" -d @max-output.json http://localhost:2000/v1/messages -# Check Membrane log: totalTokens should not exceed 200 even it was requested in max-output.json +# Check Membrane log: totalTokens should not exceed 200 even though it was requested in max-output.json api: port: 2000 diff --git a/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml index ce7ab99ef6..a86eec6a27 100644 --- a/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml +++ b/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml @@ -2,7 +2,7 @@ # # Tutorial: Basic LLM Gateway (Google Gemini) # -# Replace <> with your OpenAI API key. +# Replace <> with your Google API key. # # 1. Hello World # curl -v -H "Content-Type: application/json" -H "x-goog-api-key: <>" -d @simple.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent diff --git a/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml index 89db5cd5b7..0b7c8569e7 100644 --- a/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml @@ -24,7 +24,7 @@ # # 5. Requested Max. Output Tokens Exceeded # curl -v -H "Content-Type: application/json" -H "x-goog-api-key: abc123" -d @max-output.json http://localhost:2000/v1beta/models/gemini-2.5-flash:generateContent -# Check Membrane log: totalTokens should not exceed 200 even it was requested in max-output.json +# Check Membrane log: totalTokens should not exceed 200 even though it was requested in max-output.json api: port: 2000 From df3d145a5066234218e2bf0b6fa486954444d7f1 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 22 May 2026 08:01:09 +0200 Subject: [PATCH 31/43] feat: add streaming integration tests for OpenAI in LLM Gateway - Introduced `StreamingOpenAiLLMGatewayTutorialTest` with SSE mocking and validation. - Added JSON fixtures (`stream.json`, `max-output-stream.json`) for testing streaming requests. - Enhanced base test framework to support `text/event-stream` responses. - Updated `LLMGatewayInterceptor` to handle streaming scenarios with capped tokens. --- .../interceptor/ai/LLMGatewayInterceptor.java | 3 +- .../ai/llmgateway/AbstractAiTutorialTest.java | 12 +- ...StreamingOpenAiLLMGatewayTutorialTest.java | 135 ++++++++++++++++++ .../llm-gateway/openai/max-output-stream.json | 6 + .../ai/llm-gateway/openai/stream.json | 5 + 5 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/StreamingOpenAiLLMGatewayTutorialTest.java create mode 100644 distribution/tutorials/ai/llm-gateway/openai/max-output-stream.json create mode 100644 distribution/tutorials/ai/llm-gateway/openai/stream.json diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 8a63100c08..89cc3d2ede 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -53,12 +53,13 @@ public class LLMGatewayInterceptor extends AbstractInterceptor { @Override public void init() { + super.init(); errorCreator = provider.getErrorCreator(); if (store != null) store.init(router); // Check if the replacement markers are still there - if (apiKey.contains("<<") && apiKey.contains(">>")) { + if (apiKey != null && apiKey.contains("<<") && apiKey.contains(">>")) { throw new ConfigurationException("The configuration contains the replacement marker %s. Substitute it with the API key of the model.".formatted(apiKey)); } } diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java index fb5a37a07c..ded96258e5 100644 --- a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java @@ -29,6 +29,8 @@ import java.util.function.Consumer; +import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON; + /** * Base class for AI tutorial tests. Starts a local Membrane mock of the upstream LLM API * so tests run without a real API key and without network access to the LLM provider. @@ -103,10 +105,18 @@ protected String getApiKeyHeader() { return "x-api-key"; } + /** + * Content-Type the mock LLM server sends back. Defaults to {@code "application/json"} + * for regular responses. Override to {@code "text/event-stream"} in streaming test classes. + */ + protected String mockContentType() { + return APPLICATION_JSON; + } + private void startMockLlmApi() throws Exception { var si = new StaticInterceptor(); si.setSrc(mockResponse()); - si.setContentType("application/json"); + si.setContentType(mockContentType()); var sp = new ServiceProxy(new ServiceProxyKey(MOCK_LLM_PORT), null, 0); sp.getFlow().add(new BodyCaptureInterceptor( diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/StreamingOpenAiLLMGatewayTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/StreamingOpenAiLLMGatewayTutorialTest.java new file mode 100644 index 0000000000..679cfca6a7 --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/StreamingOpenAiLLMGatewayTutorialTest.java @@ -0,0 +1,135 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.llmgateway.openai; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; + +import static io.restassured.path.json.JsonPath.from; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Integration tests for the streaming (SSE) path of + * {@code distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml}. + * + *

The mock upstream returns {@code Content-Type: text/event-stream} with three + * SSE events so the gateway's SSE processing path is exercised end-to-end without + * a real OpenAI connection: + * + *

    + *
  • {@code response.created} — initial acknowledgement
  • + *
  • {@code response.output_text.delta} — incremental text chunk
  • + *
  • {@code response.completed} — terminal event carrying usage statistics
  • + *
+ * + *

Because RestAssured does not handle server-sent events well, these tests use the + * Java {@link java.net.http.HttpClient} directly — the same approach used in + * {@code ServerSentEventsTutorialTest}. + */ +public class StreamingOpenAiLLMGatewayTutorialTest extends AbstractOpenAiTutorialTest { + + private static final String RESPONSES_ENDPOINT = LOCALHOST_2000 + "/v1/responses"; + + @Override + protected String getTutorialYaml() { + return "10-Basic-LLM-Gateway.yaml"; + } + + /** Tell the mock server to respond as a finite SSE stream. */ + @Override + protected String mockContentType() { + return "text/event-stream"; + } + + /** + * A minimal but complete SSE body: one delta event followed by the terminal + * {@code response.completed} event that carries the usage node the gateway + * reads for token accounting. + */ + @Override + protected String mockResponse() { + return """ + event: response.created + data: {"type":"response.created","response":{"id":"resp_mock","object":"response","status":"in_progress","model":"gpt-5-nano"}} + + event: response.output_text.delta + data: {"type":"response.output_text.delta","item_id":"msg_mock","output_index":0,"content_index":0,"delta":"I am a mock."} + + event: response.completed + data: {"type":"response.completed","response":{"id":"resp_mock","object":"response","status":"completed","model":"gpt-5-nano","output":[{"type":"message","id":"msg_mock","status":"completed","role":"assistant","content":[{"type":"output_text","text":"I am a mock."}]}],"usage":{"input_tokens":50,"output_tokens":50,"total_tokens":100}}} + + """; + } + + /** + * The gateway must forward a streaming request and pass the {@code text/event-stream} + * response through to the client intact. The response body must contain the SSE events + * emitted by the upstream, including the delta text. + */ + @Test + void streamingResponseIsForwarded() throws IOException, InterruptedException { + var response = sendStreamingRequest("stream.json"); + + assertEquals(200, response.statusCode()); + assertTrue(response.headers().firstValue("content-type").orElse("").contains("text/event-stream"), + "Expected Content-Type text/event-stream"); + assertTrue(response.body().contains("response.output_text.delta"), + "SSE body must contain the delta event name"); + assertTrue(response.body().contains("I am a mock."), + "SSE body must contain the delta text"); + assertTrue(response.body().contains("response.completed"), + "SSE body must contain the terminal event"); + } + + /** + * When the request carries {@code "max_output_tokens": 500} and the gateway is + * configured with {@code maxOutputTokens: 200}, the gateway must rewrite the field + * to 200 before forwarding — even for streaming requests. + * + *

The mock captures the forwarded request body so we can assert the capped value. + */ + @Test + void streamingOutputTokensAreCappedBeforeForwarding() throws IOException, InterruptedException { + var response = sendStreamingRequest("max-output-stream.json"); + + assertEquals(200, response.statusCode()); + assertThat(from(lastRequestBody).getInt("max_output_tokens"), equalTo(200)); + } + + // ------------------------------------------------------------------------- + + private HttpResponse sendStreamingRequest(String fixture) throws IOException, InterruptedException { + var request = HttpRequest.newBuilder() + .uri(URI.create(RESPONSES_ENDPOINT)) + .timeout(Duration.ofSeconds(10)) + .header("Content-Type", "application/json") + .header("Authorization", "Bearer test-key") + .POST(HttpRequest.BodyPublishers.ofString(readFileFromBaseDir(fixture))) + .build(); + + try (var client = HttpClient.newHttpClient()) { + return client.send(request, HttpResponse.BodyHandlers.ofString()); + } + } +} diff --git a/distribution/tutorials/ai/llm-gateway/openai/max-output-stream.json b/distribution/tutorials/ai/llm-gateway/openai/max-output-stream.json new file mode 100644 index 0000000000..0a747d70e4 --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/openai/max-output-stream.json @@ -0,0 +1,6 @@ +{ + "model": "gpt-5-nano", + "input": "Explain in detail who you are?", + "max_output_tokens": 500, + "stream": true +} diff --git a/distribution/tutorials/ai/llm-gateway/openai/stream.json b/distribution/tutorials/ai/llm-gateway/openai/stream.json new file mode 100644 index 0000000000..1c75ce00aa --- /dev/null +++ b/distribution/tutorials/ai/llm-gateway/openai/stream.json @@ -0,0 +1,5 @@ +{ + "model": "gpt-5-nano", + "input": "Who are you?", + "stream": true +} From d088dab34f900872ac4d08e3e5bbef233ac2bd92 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 22 May 2026 09:01:56 +0200 Subject: [PATCH 32/43] feat: standardize API key handling and logging in LLM Gateway tests - Replaced raw API key placeholders with `TEST_API_KEY` constant in tutorial tests to ensure consistency. - Added `TEST_API_KEY` to `AbstractAiTutorialTest` for upstream key substitution verification. - Updated `log4j2.xml` to limit logging to `com.predic8.membrane.core.interceptor.ai`. - Introduced PostgreSQL dependency in `pom.xml` for future enhancements. --- .../tutorials/ai/llmgateway/AbstractAiTutorialTest.java | 8 ++++++++ .../ai/llmgateway/claude/SharingApiKeysTutorialTest.java | 2 +- .../google/SharingApiKeysGoogleTutorialTest.java | 2 +- .../openai/SharingApiKeysOpenAiTutorialTest.java | 2 +- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java index ded96258e5..77c674ee1e 100644 --- a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/AbstractAiTutorialTest.java @@ -47,6 +47,13 @@ public abstract class AbstractAiTutorialTest extends DistributionExtractingTestc protected static final int MOCK_LLM_PORT = 3100; + /** + * Value substituted for the {@code <>} placeholder in tutorial + * YAMLs before Membrane starts. Tests that verify upstream key-substitution assert against + * this constant instead of the raw placeholder text. + */ + protected static final String TEST_API_KEY = "test-upstream-key"; + protected Process2 process; protected volatile String lastRequestBody; protected volatile String lastRequestApiKey; @@ -74,6 +81,7 @@ protected String getParameters() { void startGateway() throws Exception { startMockLlmApi(); replaceInFile2(getTutorialYaml(), getUpstreamApiUrl(), mockApiUrl()); + replaceInFile2(getTutorialYaml(), "<>", TEST_API_KEY); process = startServiceProxyScript(); } diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java index 33174d6ac8..3514870774 100644 --- a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/SharingApiKeysTutorialTest.java @@ -120,7 +120,7 @@ void userApiKeyIsReplacedWithGatewayApiKey() throws IOException { // @formatter:on assertThat(lastRequestApiKey, not(equalTo(ALICE))); - assertThat(lastRequestApiKey, equalTo("<>")); + assertThat(lastRequestApiKey, equalTo(TEST_API_KEY)); } @Test diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java index 567da88b9e..79b1a71e3e 100644 --- a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/google/SharingApiKeysGoogleTutorialTest.java @@ -119,7 +119,7 @@ void userApiKeyIsReplacedWithGatewayApiKey() throws IOException { // @formatter:on assertThat(lastRequestApiKey, not(equalTo(ALICE))); - assertThat(lastRequestApiKey, equalTo("<>")); + assertThat(lastRequestApiKey, equalTo(TEST_API_KEY)); } /** diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java index 7bd410fc24..88a6d380ad 100644 --- a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java @@ -110,7 +110,7 @@ void userApiKeyIsReplacedWithGatewayApiKey() throws IOException { // @formatter:on assertThat(lastRequestApiKey, not(equalTo("Bearer " + ALICE))); - assertThat(lastRequestApiKey, equalTo("Bearer <>")); + assertThat(lastRequestApiKey, equalTo("Bearer " + TEST_API_KEY)); } @Test From ab2fac928c0ae7713a036bf257594090328b5f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20G=C3=B6rdes?= Date: Fri, 22 May 2026 12:49:27 +0200 Subject: [PATCH 33/43] Update tutorial to use Anthropic-specific API key and headers --- .../ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml index 8931dfd28a..5b30514e2f 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml +++ b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml @@ -2,18 +2,18 @@ # # Tutorial: Basic LLM Gateway (Antropic Claude) # -# Replace <> with your OpenAI API key. +# Replace <> with your Claude API key. # # 1. Hello World # curl -v -H "Content-Type: application/json" -H "x-api-key: <>" -H "anthropic-version: 2023-06-01" -d @simple.json http://localhost:2000/v1/messages # Check the response and the Membrane logs. # # 2. Exceed the input token limit -# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-input.json http://localhost:2000/v1/messages +# curl -v -H "Content-Type: application/json" -H "x-api-key: <>" -H "anthropic-version: 2023-06-01" -d @max-input.json http://localhost:2000/v1/messages # Returns an error because the request exceeds maxInputTokens. # # 3. Exceed the output token limit -# curl -H "Content-Type: application/json" -H "Authorization: Bearer <>" -d @max-output.json http://localhost:2000/v1/messages +# curl -v -H "Content-Type: application/json" -H "x-api-key: <>" -H "anthropic-version: 2023-06-01" -d @max-output.json http://localhost:2000/v1/messages # Check the Membrane log for limiting max tokens to 200 api: From 37f7515eaaa97c1e9b9c133703948dd5168e70ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20G=C3=B6rdes?= Date: Fri, 22 May 2026 14:30:02 +0200 Subject: [PATCH 34/43] Fix handling of invalid max output token requests in LLMGatewayInterceptor --- .../membrane/core/interceptor/ai/LLMGatewayInterceptor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java index 89cc3d2ede..4542c16548 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java @@ -118,7 +118,7 @@ public Outcome handleRequest(Exchange exc) { var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); if (maxOutputTokens > 0) { - if (requestedMaxOutputTokens == -1) { + if (requestedMaxOutputTokens <= 0) { log.info("No max. output requested. Setting limit to {}.", maxOutputTokens); aiReq.setMaxOutputTokens(maxOutputTokens); } else if (requestedMaxOutputTokens > maxOutputTokens) { From 44850fe510f33de3ccc67bc98eb78c38752c0803 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 26 May 2026 11:59:47 +0200 Subject: [PATCH 35/43] refactor: migrate OpenAI provider to Chat Completions framework and add policies support - Unified OpenAI and Chat Completions error handling under `ChatCompletionsErrorCreator`. - Deprecated older OpenAI-specific classes in favor of `ChatCompletions` equivalents. - Introduced detailed usage policies handling in `LLMGatewayInterceptor`. - Updated YAML tutorials to reflect the new `policies` configuration model. --- .../ai/ChatCompletionDoneEvent.java | 15 ---- .../ai/provider/AbstractLLMMessage.java | 25 ------- .../ai/provider/LLMErrorCreator.java | 24 ------ .../interceptor/ai/provider/LLMProvider.java | 13 ---- .../interceptor/ai/provider/LLMRequest.java | 29 -------- .../interceptor/ai/provider/LLMResponse.java | 18 ----- .../ai/provider/claude/ClaudeProvider.java | 33 --------- .../ai/provider/claude/ContentBlockStart.java | 23 ------ .../ai/provider/claude/ToolUse.java | 22 ------ .../ai/provider/google/GoogleProvider.java | 33 --------- .../ai/provider/openai/OpenAIProvider.java | 44 ----------- .../OpenAiChatCompletionsLLMResponse.java | 55 -------------- .../core/interceptor/ai/store/AiApiStore.java | 29 -------- .../core/interceptor/ai/store/Usage.java | 3 - .../{ai => llmgateway}/AbstractLLMEvent.java | 16 +++- .../llmgateway/ChatCompletionDoneEvent.java | 29 ++++++++ .../ChatCompletionEvent.java | 16 +++- .../LLMGatewayInterceptor.java | 74 ++++++++++--------- .../{ai => llmgateway}/ResponsesApiEvent.java | 16 +++- .../provider/AbstractLLMErrorCreator.java | 16 +++- .../provider/AbstractLLMMessage.java | 39 ++++++++++ .../provider/AbstractLLMRequest.java | 16 +++- .../provider/AbstractLLMResponse.java | 16 +++- .../llmgateway/provider/LLMErrorCreator.java | 38 ++++++++++ .../llmgateway/provider/LLMProvider.java | 27 +++++++ .../llmgateway/provider/LLMRequest.java | 43 +++++++++++ .../llmgateway/provider/LLMResponse.java | 32 ++++++++ .../ChatCompletionsErrorCreator.java} | 20 ++++- .../ChatCompletionsProvider.java | 62 ++++++++++++++++ .../ChatCompletionsRequest.java} | 23 +++++- .../ChatCompletionsResponse.java | 69 +++++++++++++++++ .../provider/claude/ClaudeErrorCreator.java | 20 ++++- .../provider/claude/ClaudeErrorResponse.java | 16 +++- .../provider/claude/ClaudeLLMRequest.java | 18 ++++- .../provider/claude/ClaudeLLMResponse.java | 22 +++++- .../provider/claude/ClaudeProvider.java | 47 ++++++++++++ .../provider/claude/ContentBlockDelta.java | 16 +++- .../provider/claude/ContentBlockStart.java | 37 ++++++++++ .../provider/claude/MessageDelta.java | 18 ++++- .../llmgateway/provider/claude/ToolUse.java | 36 +++++++++ .../provider/google/GoogleErrorCreator.java | 18 ++++- .../provider/google/GoogleLLMRequest.java | 18 ++++- .../provider/google/GoogleLLMResponse.java | 22 +++++- .../provider/google/GoogleProvider.java | 47 ++++++++++++ .../openai/AbstractOpenAiLLMRequest.java | 18 ++++- .../openai/OpenAIChatCompletionsRequest.java | 29 ++++++++ .../provider/openai/OpenAIProvider.java | 59 +++++++++++++++ .../openai/OpenAiLLMResponsesRequest.java | 16 +++- .../openai/OpenAiLLMResponsesResponse.java | 24 ++++-- .../llmgateway/store/AiApiStore.java | 43 +++++++++++ .../{ai => llmgateway}/store/AiApiUser.java | 16 +++- .../store/JDBCAiApiUsageStore.java | 18 ++++- .../store/SimpleAiApiStore.java | 16 +++- .../interceptor/llmgateway/store/Usage.java | 17 +++++ .../membrane/core/util/http/SSEParser.java | 14 ++++ .../core/util/http/SSEParserTest.java | 14 ++++ .../claude/20-Sharing-API-Keys.yaml | 9 ++- .../google/20-Sharing-API-Keys.yaml | 13 ++-- .../openai/20-Sharing-API-Keys.yaml | 9 ++- 59 files changed, 1109 insertions(+), 459 deletions(-) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/AbstractLLMEvent.java (68%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ChatCompletionDoneEvent.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/ChatCompletionEvent.java (71%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/LLMGatewayInterceptor.java (83%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/ResponsesApiEvent.java (68%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/AbstractLLMErrorCreator.java (55%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMMessage.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/AbstractLLMRequest.java (76%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/AbstractLLMResponse.java (80%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMErrorCreator.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai/provider/openai/OpenAiErrorCreator.java => llmgateway/provider/chatcompletions/ChatCompletionsErrorCreator.java} (71%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai/provider/openai/OpenAiLLMChatCompletionsRequest.java => llmgateway/provider/chatcompletions/ChatCompletionsRequest.java} (53%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsResponse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/claude/ClaudeErrorCreator.java (76%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/claude/ClaudeErrorResponse.java (78%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/claude/ClaudeLLMRequest.java (79%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/claude/ClaudeLLMResponse.java (72%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/claude/ContentBlockDelta.java (56%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ContentBlockStart.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/claude/MessageDelta.java (74%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ToolUse.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/google/GoogleErrorCreator.java (80%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/google/GoogleLLMRequest.java (82%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/google/GoogleLLMResponse.java (56%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/openai/AbstractOpenAiLLMRequest.java (75%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/openai/OpenAiLLMResponsesRequest.java (59%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/provider/openai/OpenAiLLMResponsesResponse.java (65%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiStore.java rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/store/AiApiUser.java (77%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/store/JDBCAiApiUsageStore.java (77%) rename core/src/main/java/com/predic8/membrane/core/interceptor/{ai => llmgateway}/store/SimpleAiApiStore.java (82%) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/Usage.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java deleted file mode 100644 index 520118262a..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionDoneEvent.java +++ /dev/null @@ -1,15 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai; - -import com.fasterxml.jackson.databind.node.NullNode; - -public class ChatCompletionDoneEvent extends AbstractLLMEvent { - - public ChatCompletionDoneEvent() { - super(NullNode.getInstance()); - } - - @Override - public String getType() { - return "chat.completion.done"; - } -} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java deleted file mode 100644 index 488dabe3ce..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMMessage.java +++ /dev/null @@ -1,25 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; - -public class AbstractLLMMessage { - - protected final Exchange exchange; - - public enum API { COMPLETIONS, NORMAL } - - protected API api; - - protected AbstractLLMMessage(Exchange exchange) { - this.exchange = exchange; - api = getAPI(exchange); - } - - protected API getAPI(Exchange exchange) { - if (exchange.getRequest().getUri().contains("/chat/completions")) { - return API.COMPLETIONS; - } else { - return API.NORMAL; - } - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java deleted file mode 100644 index ee06f9b7c3..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMErrorCreator.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.http.Response; - -import java.util.Collection; - -public interface LLMErrorCreator { - - Response invalidRequestError(String message); - - Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds); - - Response modelNotAllowed(String model, Collection allowedModels); - - Response authenticationFailed(); - - /** - * - * @param maxTokens as configured - * @param estimatedTokens estimated number of input tokens - * @return Response error response - */ - Response inputTokensExceeded(long maxTokens, long estimatedTokens); -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java deleted file mode 100644 index 5b52994751..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMProvider.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.exchange.Exchange; - -import java.util.function.Consumer; - -public interface LLMProvider { - - LLMRequest getLLMRequest(Exchange request); - LLMResponse getLLMResponse(Exchange request, Consumer postProcessor); - LLMErrorCreator getErrorCreator(); - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java deleted file mode 100644 index a6f377686b..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMRequest.java +++ /dev/null @@ -1,29 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.fasterxml.jackson.databind.node.ObjectNode; - -import java.util.List; - -public interface LLMRequest { - - String getModel(); - - String getApiKey(); - - void setApiKey(String apiKey); - - /** - * The max number of tokens that the model is allowed to generate as specified by the client. - * @return The max number of tokens that the model is allowed to generate. -1 if no limit is set. - */ - long getRequestedMaxOutputTokens(); - - void setMaxOutputTokens(int maxOutputTokens); - - long estimateInputTokens(); - - ObjectNode getJson(); - - List getTools(); - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java deleted file mode 100644 index fd4979ca7e..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/LLMResponse.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider; - -import com.predic8.membrane.core.interceptor.ai.store.Usage; -import com.predic8.membrane.core.util.http.SSEParser.SSEEvent; - -import java.util.Set; - -public interface LLMResponse { - - boolean isError(); - - Usage getUsage(); - - Set getTerminalEvents(); - - void process(SSEEvent event); - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java deleted file mode 100644 index 99ba4820e7..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeProvider.java +++ /dev/null @@ -1,33 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; - -import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; -import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; - -import java.util.function.Consumer; - -/** - * @description (Experimental) Anthroic Claude provider configuration - * Use to configure a LLM gateway to use the anthropic API - */ -@MCElement( name="claude") -public class ClaudeProvider implements LLMProvider { - - @Override - public LLMRequest getLLMRequest(Exchange exchange) { - return new ClaudeLLMRequest(exchange); - } - - @Override - public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { - return new ClaudeLLMResponse(request, postProcessor); - } - - @Override - public LLMErrorCreator getErrorCreator() { - return new ClaudeErrorCreator(); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java deleted file mode 100644 index b98f1d5827..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockStart.java +++ /dev/null @@ -1,23 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; - -import com.fasterxml.jackson.databind.node.ObjectNode; - -public class ContentBlockStart { - - private ToolUse toolUse; - - public static ContentBlockStart from(ObjectNode on) { - var cbs = new ContentBlockStart(); - var cb = (ObjectNode) on.path("content_block"); - - if ("tool_use".equals(cb.path("type").asText())) { - cbs.toolUse = ToolUse.from(cb); - } - - return cbs; - } - - public ToolUse getToolUse() { - return toolUse; - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java deleted file mode 100644 index 59ef545e68..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ToolUse.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; - -import com.fasterxml.jackson.databind.node.ObjectNode; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class ToolUse { - - private static final Logger log = LoggerFactory.getLogger(ToolUse.class); - - private String name; - - public static ToolUse from(ObjectNode on) { - var tu = new ToolUse(); - tu.name = on.path("name").asText(); - return tu; - } - - public String getName() { - return name; - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java deleted file mode 100644 index 4ee8860f89..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleProvider.java +++ /dev/null @@ -1,33 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider.google; - -import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; -import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; - -import java.util.function.Consumer; - -/** - * @description (Experimental)Google AI provider configuration - * Use to configure a LLM gateway to use the Google LLM API - */ -@MCElement( name="google",id = "google-ai-provider") -public class GoogleProvider implements LLMProvider { - - @Override - public LLMRequest getLLMRequest(Exchange exchange) { - return new GoogleLLMRequest(exchange); - } - - @Override - public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { - return new GoogleLLMResponse(request, postProcessor); - } - - @Override - public LLMErrorCreator getErrorCreator() { - return new GoogleErrorCreator(); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java deleted file mode 100644 index 8a1aa29436..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAIProvider.java +++ /dev/null @@ -1,44 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; - -import com.predic8.membrane.annot.MCElement; -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; -import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; - -import java.util.function.Consumer; - -/** - * @description OpenAI provider configuration - * Use to configure a LLM gateway to use the OpenAI API - */ -@MCElement( name="openai") -public class OpenAIProvider implements LLMProvider { - - @Override - public LLMRequest getLLMRequest(Exchange exchange) { - if (isResponsesApi(exchange)) { - return new OpenAiLLMResponsesRequest(exchange); - } - - return new OpenAiLLMChatCompletionsRequest(exchange); - } - - @Override - public LLMResponse getLLMResponse(Exchange exchange, Consumer postProcessor) { - if (isResponsesApi(exchange)) { - return new OpenAiLLMResponsesResponse(exchange,postProcessor); - } - return new OpenAiChatCompletionsLLMResponse(exchange, postProcessor); - } - - @Override - public LLMErrorCreator getErrorCreator() { - return new OpenAiErrorCreator(); - } - - static boolean isResponsesApi(Exchange exchange) { - return exchange.getRequest().getUri().startsWith("/v1/responses"); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java deleted file mode 100644 index bb0a206a8c..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiChatCompletionsLLMResponse.java +++ /dev/null @@ -1,55 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMEvent; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; -import com.predic8.membrane.core.interceptor.ai.store.Usage; -import com.predic8.membrane.core.util.http.SSEParser; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Set; -import java.util.function.Consumer; - -public class OpenAiChatCompletionsLLMResponse extends AbstractLLMResponse { - - private static final Logger log = LoggerFactory.getLogger(OpenAiChatCompletionsLLMResponse.class); - - public OpenAiChatCompletionsLLMResponse(Exchange exchange, Consumer postProcessor) { - super(exchange, postProcessor); - } - - @Override - public Usage getUsage() { - - var usage = json.path("usage"); - - var inputTokens = usage.path("prompt_tokens").asInt(0); - var outputTokens = usage.path("completion_tokens").asInt(0); - var totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); - - return new Usage( - inputTokens, - outputTokens, - totalTokens - ); - } - - @Override - public Set getTerminalEvents() { - return Set.of("[DONE]"); - } - - @Override - protected void processTerminalEvent(SSEParser.SSEEvent terminal) { - postProcessor.accept(OpenAiChatCompletionsLLMResponse.this); - } - - @Override - public void process(SSEParser.SSEEvent e) { - log.debug("Data: {}", e.data()); - var event = AbstractLLMEvent.create(e); - log.debug("Event: {}", event); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java deleted file mode 100644 index 73674eeef1..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiStore.java +++ /dev/null @@ -1,29 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.store; - -import com.predic8.membrane.core.router.Router; - -import java.util.Optional; - -/** - * @TODO - * - Store .status, .error, .model, .stop_reason - */ -public interface AiApiStore { - - default void init(Router router) { - } - - void store(AiApiUser user, Usage usage); - - Optional getUser(String token); - - /** - * Checks if the user has enough tokens to make the request. - * @param user The user to check - * @return Estimated number of tokens that the user has left after this request - */ - long checkLimit(AiApiUser user, long inputTokens, long outputTokens); - - long getRemainingResetTime(); -} - diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java b/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java deleted file mode 100644 index 9288bba508..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/Usage.java +++ /dev/null @@ -1,3 +0,0 @@ -package com.predic8.membrane.core.interceptor.ai.store; - -public record Usage(int inputTokens, int outputTokens, int totalTokens) {} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/AbstractLLMEvent.java similarity index 68% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMEvent.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/AbstractLLMEvent.java index d39b3ddafe..ed9fe0929c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/AbstractLLMEvent.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/AbstractLLMEvent.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ChatCompletionDoneEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ChatCompletionDoneEvent.java new file mode 100644 index 0000000000..cc234b8113 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ChatCompletionDoneEvent.java @@ -0,0 +1,29 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway; + +import com.fasterxml.jackson.databind.node.NullNode; + +public class ChatCompletionDoneEvent extends AbstractLLMEvent { + + public ChatCompletionDoneEvent() { + super(NullNode.getInstance()); + } + + @Override + public String getType() { + return "chat.completion.done"; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ChatCompletionEvent.java similarity index 71% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ChatCompletionEvent.java index 86531144d9..1fde1e736f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ChatCompletionEvent.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ChatCompletionEvent.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway; import com.fasterxml.jackson.databind.JsonNode; import org.slf4j.Logger; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java similarity index 83% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java index 89cc3d2ede..3725415684 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway; import com.predic8.membrane.annot.MCAttribute; import com.predic8.membrane.annot.MCChildElement; @@ -6,17 +20,15 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.AbstractInterceptor; import com.predic8.membrane.core.interceptor.Outcome; -import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.ai.provider.LLMProvider; -import com.predic8.membrane.core.interceptor.ai.provider.LLMRequest; -import com.predic8.membrane.core.interceptor.ai.store.AiApiStore; -import com.predic8.membrane.core.interceptor.ai.store.AiApiUser; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.store.AiApiStore; +import com.predic8.membrane.core.interceptor.llmgateway.store.AiApiUser; import com.predic8.membrane.core.util.ConfigurationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; - import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; import static com.predic8.membrane.core.util.json.JsonUtil.setJsonBody; @@ -47,7 +59,8 @@ public class LLMGatewayInterceptor extends AbstractInterceptor { private String apiKey; private int maxOutputTokens; private int maxInputTokens; - private List models; + + private Policies policies = new Policies(); private AiApiStore store; @@ -135,10 +148,10 @@ public Outcome handleRequest(Exchange exc) { } } - if (models != null) { + if (policies.getModels() != null) { var model = aiReq.getModel(); - if (!models.contains(model)) { - exc.setResponse(errorCreator.modelNotAllowed(model, models)); + if (!policies.getModels().contains(model)) { + exc.setResponse(errorCreator.modelNotAllowed(model, policies.getModels())); return RETURN; } } @@ -157,14 +170,9 @@ long computeEffectiveMaxOutputTokens(long requestedMaxOutputTokens, long maxOutp @Override public Outcome handleResponse(Exchange exc) { - provider.getLLMResponse(exc, res -> { var user = exc.getProperty(MEMBRANE_AI_USER, AiApiUser.class); - if (log.isDebugEnabled() && user != null) { - log.debug("Token usage of user {}: {}", user, res.getUsage()); - } else { - log.info("Token usage: {}", res.getUsage()); - } + log.debug("Token usage of user {}: {}", user, res.getUsage()); if (store != null) { store.store(user, res.getUsage()); } @@ -196,7 +204,7 @@ public AiApiStore getAiStore() { * A store is needed for user authentication at the gateway. * The gateway will use the store to enforce token limits and log usage statistics. */ - @MCChildElement(allowForeign = true, order = 10) + @MCChildElement(allowForeign = true, order = 30) public void setAiStore(AiApiStore store) { this.store = store; } @@ -235,19 +243,6 @@ public void setMaxInputTokens(int maxInputTokens) { this.maxInputTokens = maxInputTokens; } - public List getModels() { - return models; - } - - /** - * @param models List of models that can be used by the gateway. - * @desciption Restricts the models that can be used by the gateway. - * @default null (no restriction) - */ - @MCAttribute - public void setModels(List models) { - this.models = models; - } public LLMProvider getProvider() { return provider; @@ -258,8 +253,21 @@ public LLMProvider getProvider() { * @description The LLM provider to use. Currently, OpenAI, Anthropic and Gemini are supported. * The provider determines the API used to talk to the LLM. The provider can be different as long as the API is supported. */ - @MCChildElement(allowForeign = true) + @MCChildElement(order = 10) public void setProvider(LLMProvider provider) { this.provider = provider; } + + public Policies getPolicies() { + return policies; + } + + /** + * + * @param policies Usage policy for the LLM Gateway. + */ + @MCChildElement(order = 20) + public void setPolicies(Policies policies) { + this.policies = policies; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ResponsesApiEvent.java similarity index 68% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ResponsesApiEvent.java index af2a351dc6..4b726bec62 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/ResponsesApiEvent.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/ResponsesApiEvent.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMErrorCreator.java similarity index 55% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMErrorCreator.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMErrorCreator.java index 6e6d739711..6ecf4d7ef5 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMErrorCreator.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMMessage.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMMessage.java new file mode 100644 index 0000000000..391324f38e --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMMessage.java @@ -0,0 +1,39 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.predic8.membrane.core.exchange.Exchange; + +public class AbstractLLMMessage { + + protected final Exchange exchange; + + public enum API { COMPLETIONS, NORMAL } + + protected API api; + + protected AbstractLLMMessage(Exchange exchange) { + this.exchange = exchange; + api = getAPI(exchange); + } + + protected API getAPI(Exchange exchange) { + if (exchange.getRequest().getUri().contains("/chat/completions")) { + return API.COMPLETIONS; + } else { + return API.NORMAL; + } + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java similarity index 76% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java index 95ecc7a77e..f5955d6acb 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMResponse.java similarity index 80% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMResponse.java index 4967d34af4..4732d0a0a5 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/AbstractLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMResponse.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.provider; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMErrorCreator.java new file mode 100644 index 0000000000..732a1332fe --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMErrorCreator.java @@ -0,0 +1,38 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.predic8.membrane.core.http.Response; + +import java.util.Collection; + +public interface LLMErrorCreator { + + Response invalidRequestError(String message); + + Response tokenLimitExceeded(long tokenRequired, long tokenRemaining, long tokenResetInSeconds); + + Response modelNotAllowed(String model, Collection allowedModels); + + Response authenticationFailed(); + + /** + * + * @param maxTokens as configured + * @param estimatedTokens estimated number of input tokens + * @return Response error response + */ + Response inputTokensExceeded(long maxTokens, long estimatedTokens); +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java new file mode 100644 index 0000000000..1fb2fc4eae --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java @@ -0,0 +1,27 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.predic8.membrane.core.exchange.Exchange; + +import java.util.function.Consumer; + +public interface LLMProvider { + + LLMRequest getLLMRequest(Exchange request); + LLMResponse getLLMResponse(Exchange request, Consumer postProcessor); + LLMErrorCreator getErrorCreator(); + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java new file mode 100644 index 0000000000..371115e911 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java @@ -0,0 +1,43 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.fasterxml.jackson.databind.node.ObjectNode; + +import java.util.List; + +public interface LLMRequest { + + String getModel(); + + String getApiKey(); + + void setApiKey(String apiKey); + + /** + * The max number of tokens that the model is allowed to generate as specified by the client. + * @return The max number of tokens that the model is allowed to generate. -1 if no limit is set. + */ + long getRequestedMaxOutputTokens(); + + void setMaxOutputTokens(int maxOutputTokens); + + long estimateInputTokens(); + + ObjectNode getJson(); + + List getTools(); + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMResponse.java new file mode 100644 index 0000000000..3d3ed9bd78 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMResponse.java @@ -0,0 +1,32 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.predic8.membrane.core.interceptor.llmgateway.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser.SSEEvent; + +import java.util.Set; + +public interface LLMResponse { + + boolean isError(); + + Usage getUsage(); + + Set getTerminalEvents(); + + void process(SSEEvent event); + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsErrorCreator.java similarity index 71% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsErrorCreator.java index 7f51494ad5..643786b0a4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsErrorCreator.java @@ -1,14 +1,28 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions; import com.predic8.membrane.core.http.Response; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMErrorCreator; import java.util.Collection; import static com.predic8.membrane.core.http.Header.WWW_AUTHENTICATE; import static com.predic8.membrane.core.http.Response.*; -public class OpenAiErrorCreator extends AbstractLLMErrorCreator { +public class ChatCompletionsErrorCreator extends AbstractLLMErrorCreator { @Override public Response invalidRequestError(String message) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java new file mode 100644 index 0000000000..1ac5be3699 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java @@ -0,0 +1,62 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; + +import java.util.function.Consumer; + +/** + * @description + * OpenAI Chat Completions API compatible provider. + * Can be used for the following providers: + *

    + *
  • Azure OpenAI
  • + *
  • Google Gemini (OpenAI compatible endpoint)
  • + *
  • TogetherAI
  • + *
  • Fireworks AI
  • + *
  • DeepSeek AI
  • + *
  • OpenRouter
  • + *
  • Mistral AI
  • + *
  • DeepInfra
  • + *
  • SiliconFlow
  • + *
  • NVIDIA NIM
  • + *
  • ML Studio
  • + *
  • vLLM
  • + *
  • Ollama
  • + *
+ */ +@MCElement(name = "chatCompletions") +public class ChatCompletionsProvider implements LLMProvider { + @Override + public LLMRequest getLLMRequest(Exchange request) { + return new ChatCompletionsRequest(request); + } + + @Override + public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { + return new ChatCompletionsResponse(request, postProcessor); + } + + @Override + public LLMErrorCreator getErrorCreator() { + return null; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java similarity index 53% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java index 5c57339682..4ecbf9065a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMChatCompletionsRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java @@ -1,14 +1,29 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions; import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.AbstractOpenAiLLMRequest; import java.util.List; import static java.util.Collections.emptyList; -public class OpenAiLLMChatCompletionsRequest extends AbstractOpenAiLLMRequest { +public class ChatCompletionsRequest extends AbstractOpenAiLLMRequest { - public OpenAiLLMChatCompletionsRequest(Exchange exchange) { + public ChatCompletionsRequest(Exchange exchange) { super(exchange); if (json == null) { @@ -24,7 +39,7 @@ public OpenAiLLMChatCompletionsRequest(Exchange exchange) { @Override public void setMaxOutputTokens(int maxOutputTokens) { - json.put("max_completion_tokens", maxOutputTokens); + json.put("max_tokens", maxOutputTokens); } public List getTools() { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsResponse.java new file mode 100644 index 0000000000..2b1acc0047 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsResponse.java @@ -0,0 +1,69 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.AbstractLLMEvent; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.store.Usage; +import com.predic8.membrane.core.util.http.SSEParser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; +import java.util.function.Consumer; + +public class ChatCompletionsResponse extends AbstractLLMResponse { + + private static final Logger log = LoggerFactory.getLogger(ChatCompletionsResponse.class); + + public ChatCompletionsResponse(Exchange exchange, Consumer postProcessor) { + super(exchange, postProcessor); + } + + @Override + public Usage getUsage() { + + var usage = json.path("usage"); + + var inputTokens = usage.path("prompt_tokens").asInt(0); + var outputTokens = usage.path("completion_tokens").asInt(0); + var totalTokens = usage.path("total_tokens").asInt(inputTokens + outputTokens); + + return new Usage( + inputTokens, + outputTokens, + totalTokens + ); + } + + @Override + public Set getTerminalEvents() { + return Set.of("[DONE]"); + } + + @Override + protected void processTerminalEvent(SSEParser.SSEEvent terminal) { + postProcessor.accept(ChatCompletionsResponse.this); + } + + @Override + public void process(SSEParser.SSEEvent e) { + log.debug("Data: {}", e.data()); + var event = AbstractLLMEvent.create(e); + log.debug("Event: {}", event); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeErrorCreator.java similarity index 76% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeErrorCreator.java index 506b6d2083..1fbcf2f1a1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeErrorCreator.java @@ -1,8 +1,22 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; import com.predic8.membrane.core.http.Response; -import com.predic8.membrane.core.interceptor.ai.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.ai.provider.claude.ClaudeErrorResponse.ClaudeError; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.claude.ClaudeErrorResponse.ClaudeError; import java.util.Collection; import java.util.UUID; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeErrorResponse.java similarity index 78% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeErrorResponse.java index 2bdce96c2e..0ff004834e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeErrorResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeErrorResponse.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.core.JsonProcessingException; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java similarity index 79% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java index 2a1151e855..fa5279afe4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java @@ -1,8 +1,22 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMResponse.java similarity index 72% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMResponse.java index f487d43b0d..8d534643ea 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ClaudeLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMResponse.java @@ -1,10 +1,24 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; -import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.store.Usage; import com.predic8.membrane.core.util.http.SSEParser.SSEEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java new file mode 100644 index 0000000000..a296575058 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java @@ -0,0 +1,47 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; + +import java.util.function.Consumer; + +/** + * @description (Experimental) Anthroic Claude provider configuration + * Use to configure a LLM gateway to use the anthropic API + */ +@MCElement( name="claude") +public class ClaudeProvider implements LLMProvider { + + @Override + public LLMRequest getLLMRequest(Exchange exchange) { + return new ClaudeLLMRequest(exchange); + } + + @Override + public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { + return new ClaudeLLMResponse(request, postProcessor); + } + + @Override + public LLMErrorCreator getErrorCreator() { + return new ClaudeErrorCreator(); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockDelta.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ContentBlockDelta.java similarity index 56% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockDelta.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ContentBlockDelta.java index f8d32c1c97..5e5a0648bb 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/ContentBlockDelta.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ContentBlockDelta.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ContentBlockStart.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ContentBlockStart.java new file mode 100644 index 0000000000..bdf2be207b --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ContentBlockStart.java @@ -0,0 +1,37 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; + +import com.fasterxml.jackson.databind.node.ObjectNode; + +public class ContentBlockStart { + + private ToolUse toolUse; + + public static ContentBlockStart from(ObjectNode on) { + var cbs = new ContentBlockStart(); + var cb = (ObjectNode) on.path("content_block"); + + if ("tool_use".equals(cb.path("type").asText())) { + cbs.toolUse = ToolUse.from(cb); + } + + return cbs; + } + + public ToolUse getToolUse() { + return toolUse; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/MessageDelta.java similarity index 74% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/MessageDelta.java index a99b04e2bb..4aa68fa737 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/claude/MessageDelta.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/MessageDelta.java @@ -1,8 +1,22 @@ -package com.predic8.membrane.core.interceptor.ai.provider.claude; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; -import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.interceptor.llmgateway.store.Usage; public class MessageDelta { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ToolUse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ToolUse.java new file mode 100644 index 0000000000..5694468d9e --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ToolUse.java @@ -0,0 +1,36 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.claude; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ToolUse { + + private static final Logger log = LoggerFactory.getLogger(ToolUse.class); + + private String name; + + public static ToolUse from(ObjectNode on) { + var tu = new ToolUse(); + tu.name = on.path("name").asText(); + return tu; + } + + public String getName() { + return name; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleErrorCreator.java similarity index 80% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleErrorCreator.java index 92a115194a..1b86f0f39b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleErrorCreator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleErrorCreator.java @@ -1,7 +1,21 @@ -package com.predic8.membrane.core.interceptor.ai.provider.google; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.google; import com.predic8.membrane.core.http.Response; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMErrorCreator; import java.util.Collection; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java similarity index 82% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java index 07da55b089..bd60b10617 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java @@ -1,9 +1,23 @@ -package com.predic8.membrane.core.interceptor.ai.provider.google; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.google; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; public class GoogleLLMRequest extends AbstractLLMRequest { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMResponse.java similarity index 56% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMResponse.java index db04ae85df..abf1c0a592 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/google/GoogleLLMResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMResponse.java @@ -1,9 +1,23 @@ -package com.predic8.membrane.core.interceptor.ai.provider.google; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.google; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; -import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.store.Usage; import com.predic8.membrane.core.util.http.SSEParser; import java.util.Set; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java new file mode 100644 index 0000000000..b1b36ea1df --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java @@ -0,0 +1,47 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.google; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; + +import java.util.function.Consumer; + +/** + * @description (Experimental)Google AI provider configuration + * Use to configure a LLM gateway to use the Google LLM API + */ +@MCElement( name="google",id = "google-ai-provider") +public class GoogleProvider implements LLMProvider { + + @Override + public LLMRequest getLLMRequest(Exchange exchange) { + return new GoogleLLMRequest(exchange); + } + + @Override + public LLMResponse getLLMResponse(Exchange request, Consumer postProcessor) { + return new GoogleLLMResponse(request, postProcessor); + } + + @Override + public LLMErrorCreator getErrorCreator() { + return new GoogleErrorCreator(); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/AbstractOpenAiLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java similarity index 75% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/AbstractOpenAiLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java index 7b8a76e4d1..b49e7440fc 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/AbstractOpenAiLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java @@ -1,8 +1,22 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; public abstract class AbstractOpenAiLLMRequest extends AbstractLLMRequest { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java new file mode 100644 index 0000000000..8c6e474398 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java @@ -0,0 +1,29 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsRequest; + +public class OpenAIChatCompletionsRequest extends ChatCompletionsRequest { + public OpenAIChatCompletionsRequest(Exchange exchange) { + super(exchange); + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + json.put("max_completion_tokens", maxOutputTokens); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java new file mode 100644 index 0000000000..e55d40bd47 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java @@ -0,0 +1,59 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; + +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMProvider; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsResponse; + +import java.util.function.Consumer; + +/** + * @description OpenAI provider configuration + * Use to configure a LLM gateway to use the OpenAI API + */ +@MCElement( name="openai") +public class OpenAIProvider implements LLMProvider { + + @Override + public LLMRequest getLLMRequest(Exchange exchange) { + if (isResponsesApi(exchange)) { + return new OpenAiLLMResponsesRequest(exchange); + } + return new OpenAIChatCompletionsRequest(exchange); + } + + @Override + public LLMResponse getLLMResponse(Exchange exchange, Consumer postProcessor) { + if (isResponsesApi(exchange)) { + return new OpenAiLLMResponsesResponse(exchange,postProcessor); + } + return new ChatCompletionsResponse(exchange, postProcessor); + } + + @Override + public LLMErrorCreator getErrorCreator() { + return new ChatCompletionsErrorCreator(); + } + + static boolean isResponsesApi(Exchange exchange) { + return exchange.getRequest().getUri().startsWith("/v1/responses"); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java similarity index 59% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java index 2568755848..3caa187c88 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; import com.predic8.membrane.core.exchange.Exchange; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesResponse.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesResponse.java similarity index 65% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesResponse.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesResponse.java index 67e836ebd6..15263fbd55 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/provider/openai/OpenAiLLMResponsesResponse.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesResponse.java @@ -1,11 +1,25 @@ -package com.predic8.membrane.core.interceptor.ai.provider.openai; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.ai.AbstractLLMEvent; -import com.predic8.membrane.core.interceptor.ai.provider.AbstractLLMResponse; -import com.predic8.membrane.core.interceptor.ai.provider.LLMResponse; -import com.predic8.membrane.core.interceptor.ai.store.Usage; +import com.predic8.membrane.core.interceptor.llmgateway.AbstractLLMEvent; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import com.predic8.membrane.core.interceptor.llmgateway.store.Usage; import com.predic8.membrane.core.util.http.SSEParser; import com.predic8.membrane.core.util.json.JsonUtil; import org.slf4j.Logger; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiStore.java new file mode 100644 index 0000000000..c764e17ac9 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiStore.java @@ -0,0 +1,43 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.store; + +import com.predic8.membrane.core.router.Router; + +import java.util.Optional; + +/** + * @TODO + * - Store .status, .error, .model, .stop_reason + */ +public interface AiApiStore { + + default void init(Router router) { + } + + void store(AiApiUser user, Usage usage); + + Optional getUser(String token); + + /** + * Checks if the user has enough tokens to make the request. + * @param user The user to check + * @return Estimated number of tokens that the user has left after this request + */ + long checkLimit(AiApiUser user, long inputTokens, long outputTokens); + + long getRemainingResetTime(); +} + diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiUser.java similarity index 77% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiUser.java index cd3ab76b4b..da8b792680 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiUser.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.store; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.store; import com.predic8.membrane.annot.MCAttribute; import com.predic8.membrane.annot.MCElement; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/JDBCAiApiUsageStore.java similarity index 77% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/JDBCAiApiUsageStore.java index 16457a97db..7541c08a2c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/JDBCAiApiUsageStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/JDBCAiApiUsageStore.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.store; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.store; import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.router.Router; @@ -41,7 +55,7 @@ public void init(Router router) { } @Override - public void store(AiApiUser user, com.predic8.membrane.core.interceptor.ai.store.Usage usage) { + public void store(AiApiUser user, com.predic8.membrane.core.interceptor.llmgateway.store.Usage usage) { try (var connection = getConnection(); var ps = connection.prepareStatement(INSERT_SQL)) { ps.setString(1, user.getName()); ps.setInt(2, usage.inputTokens()); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/SimpleAiApiStore.java similarity index 82% rename from core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/SimpleAiApiStore.java index f9b8218608..106892c39f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/ai/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/SimpleAiApiStore.java @@ -1,4 +1,18 @@ -package com.predic8.membrane.core.interceptor.ai.store; +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.store; import com.predic8.membrane.annot.MCAttribute; import com.predic8.membrane.annot.MCChildElement; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/Usage.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/Usage.java new file mode 100644 index 0000000000..3bcc626858 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/Usage.java @@ -0,0 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.store; + +public record Usage(int inputTokens, int outputTokens, int totalTokens) {} diff --git a/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java b/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java index acddbc7428..405312ba4e 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java +++ b/core/src/main/java/com/predic8/membrane/core/util/http/SSEParser.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.util.http; import com.fasterxml.jackson.core.JsonProcessingException; diff --git a/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java b/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java index c08ecd3a09..7738321f85 100644 --- a/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java +++ b/core/src/test/java/com/predic8/membrane/core/util/http/SSEParserTest.java @@ -1,3 +1,17 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + package com.predic8.membrane.core.util.http; import com.predic8.membrane.core.http.Chunk; diff --git a/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml index 44bc28a8ee..e3550da714 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml @@ -35,10 +35,11 @@ api: # Limits per request maxInputTokens: 100 maxOutputTokens: 200 - models: - - claude-sonnet-4-0 - - claude-opus-4-0 - - claude-haiku-3-5 + policies: + models: + - claude-sonnet-4-0 + - claude-opus-4-0 + - claude-haiku-3-5 simpleStore: # User-facing API keys for the LLM Gateway users: diff --git a/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml index 0b7c8569e7..2b6e344edd 100644 --- a/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml @@ -35,12 +35,13 @@ api: # Limits per request maxInputTokens: 100 maxOutputTokens: 200 - models: - - gemini-2.5-pro - - gemini-2.5-flash - - gemini-2.5-flash-lite - - gemini-2.0-flash - - gemini-2.0-flash-lite + policies: + models: + - gemini-2.5-pro + - gemini-2.5-flash + - gemini-2.5-flash-lite + - gemini-2.0-flash + - gemini-2.0-flash-lite simpleStore: # User-facing API keys for the LLM Gateway users: diff --git a/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml index b0231ca3d8..19f8295c69 100644 --- a/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml @@ -34,10 +34,11 @@ api: # Limits per request maxInputTokens: 100 maxOutputTokens: 200 - models: - - gpt-5.4 - - gpt-5-nano - - gpt-5-mini + policies: + models: + - gpt-5.4 + - gpt-5-nano + - gpt-5-mini openai: {} simpleStore: # User-facing API keys for the LLM Gateway From b644a65df7a7782b8f099821f2ca72f6a22ca0c0 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 26 May 2026 15:51:50 +0200 Subject: [PATCH 36/43] feat: introduce `Policies` class and update LLM Gateway to support policy-based token and model restrictions - Added `Policies` class for defining restrictions on models, input tokens, and output tokens in the LLM Gateway. - Replaced `maxInputTokens` and `maxOutputTokens` fields in `LLMGatewayInterceptor` with `Policies`. - Updated YAML tutorials (OpenAI, Claude, Google) to use the new `policies` configuration. --- .../llmgateway/LLMGatewayInterceptor.java | 54 +++---------- .../core/interceptor/llmgateway/Policies.java | 76 +++++++++++++++++++ .../claude/10-Basic-LLM-Gateway.yaml | 5 +- .../claude/20-Sharing-API-Keys.yaml | 14 ++-- .../google/10-Basic-LLM-Gateway.yaml | 5 +- .../google/20-Sharing-API-Keys.yaml | 6 +- .../openai/10-Basic-LLM-Gateway.yaml | 5 +- .../openai/20-Sharing-API-Keys.yaml | 6 +- 8 files changed, 109 insertions(+), 62 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java index 89197b079b..2824842f67 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java @@ -57,8 +57,6 @@ public class LLMGatewayInterceptor extends AbstractInterceptor { private LLMErrorCreator errorCreator; private String apiKey; - private int maxOutputTokens; - private int maxInputTokens; private Policies policies = new Policies(); @@ -111,7 +109,7 @@ public Outcome handleRequest(Exchange exc) { // Check store limits if (store != null) { - var effectiveMaxTokens = computeEffectiveMaxOutputTokens(aiReq.getRequestedMaxOutputTokens(), maxOutputTokens); + var effectiveMaxTokens = computeEffectiveMaxOutputTokens(aiReq.getRequestedMaxOutputTokens(), policies.getMaxOutputTokens()); var remaining = store.checkLimit(user, inputTokens, effectiveMaxTokens); log.debug("User {} has {} remaining tokens left", user, remaining); if (remaining <= 0) { @@ -130,20 +128,20 @@ public Outcome handleRequest(Exchange exc) { var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); - if (maxOutputTokens > 0) { + if (policies.getMaxOutputTokens() > 0) { if (requestedMaxOutputTokens <= 0) { - log.info("No max. output requested. Setting limit to {}.", maxOutputTokens); - aiReq.setMaxOutputTokens(maxOutputTokens); - } else if (requestedMaxOutputTokens > maxOutputTokens) { - log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, maxOutputTokens); - aiReq.setMaxOutputTokens(maxOutputTokens); + log.info("No max. output requested. Setting limit to {}.", policies.getMaxOutputTokens()); + aiReq.setMaxOutputTokens(policies.getMaxOutputTokens()); + } else if (requestedMaxOutputTokens > policies.getMaxOutputTokens()) { + log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, policies.getMaxOutputTokens()); + aiReq.setMaxOutputTokens(policies.getMaxOutputTokens()); } } - if (maxInputTokens != 0) { - if (inputTokens > maxInputTokens) { - log.info("Input tokens {} exceed the limit of {}.", inputTokens, maxInputTokens); - exc.setResponse(errorCreator.inputTokensExceeded(maxInputTokens, inputTokens)); + if (policies.getMaxInputTokens() != 0) { + if (inputTokens > policies.getMaxInputTokens()) { + log.info("Input tokens {} exceed the limit of {}.", inputTokens, policies.getMaxInputTokens()); + exc.setResponse(errorCreator.inputTokensExceeded(policies.getMaxInputTokens(), inputTokens)); return RETURN; } } @@ -214,36 +212,6 @@ public String getDisplayName() { return "LLM Gateway"; } - public int getMaxOutputTokens() { - return maxOutputTokens; - } - - /** - * @param maxOutputTokens Maximum number of tokens the LLM should use to generate a response. - * @description Maximum number of tokens the LLM should use to generate a response. This is just a hint that the gateway - * sends to the LLM provider. The provider may use a different limit. - * @default 0 (unlimited) - */ - @MCAttribute - public void setMaxOutputTokens(int maxOutputTokens) { - this.maxOutputTokens = maxOutputTokens; - } - - public int getMaxInputTokens() { - return maxInputTokens; - } - - /** - * @param maxInputTokens Maximum number of tokens that a request can use. - * @description Restricts token usage for the input. The size of the input is estimated by gateway based on the request size. - * Actual token usage may be deviate from this value. - */ - @MCAttribute - public void setMaxInputTokens(int maxInputTokens) { - this.maxInputTokens = maxInputTokens; - } - - public LLMProvider getProvider() { return provider; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java new file mode 100644 index 0000000000..cfbf960e10 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java @@ -0,0 +1,76 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway; + +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; + +import java.util.List; + +/** + * LLM Gateway policies for token usage and model restrictions. + */ +@MCElement(name = "policies", topLevel = false, id="llm-gateway-policies") +public class Policies { + + private List models; + private int maxOutputTokens; + private int maxInputTokens; + + public List getModels() { + return models; + } + + /** + * @param models List of models that can be used by the gateway. + * @desciption Restricts the models that can be used by the gateway. + * @default null (no restriction) + */ + @MCAttribute + public void setModels(List models) { + this.models = models; + } + + + public int getMaxOutputTokens() { + return maxOutputTokens; + } + + /** + * @param maxOutputTokens Maximum number of tokens the LLM should use to generate a response. + * @description Maximum number of tokens the LLM should use to generate a response. This is just a hint that the gateway + * sends to the LLM provider. The provider may use a different limit. + * @default 0 (unlimited) + */ + @MCAttribute + public void setMaxOutputTokens(int maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + } + + public int getMaxInputTokens() { + return maxInputTokens; + } + + /** + * @param maxInputTokens Maximum number of tokens that a request can use. + * @description Restricts token usage for the input. The size of the input is estimated by gateway based on the request size. + * Actual token usage may be deviate from this value. + */ + @MCAttribute + public void setMaxInputTokens(int maxInputTokens) { + this.maxInputTokens = maxInputTokens; + } + +} diff --git a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml index 5b30514e2f..ddaaaedcf1 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml +++ b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml @@ -21,7 +21,8 @@ api: flow: - llmGateway: claude: {} - maxInputTokens: 100 - maxOutputTokens: 200 + policies: + maxInputTokens: 100 + maxOutputTokens: 200 target: url: https://api.anthropic.com \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml index e3550da714..3a6a54f2f4 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/claude/20-Sharing-API-Keys.yaml @@ -32,14 +32,14 @@ api: - llmGateway: claude: {} apiKey: <> - # Limits per request - maxInputTokens: 100 - maxOutputTokens: 200 policies: - models: - - claude-sonnet-4-0 - - claude-opus-4-0 - - claude-haiku-3-5 + # Limits per request + maxInputTokens: 100 + maxOutputTokens: 200 + models: + - claude-sonnet-4-0 + - claude-opus-4-0 + - claude-haiku-3-5 simpleStore: # User-facing API keys for the LLM Gateway users: diff --git a/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml index a86eec6a27..2cbf4c236d 100644 --- a/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml +++ b/distribution/tutorials/ai/llm-gateway/google/10-Basic-LLM-Gateway.yaml @@ -21,7 +21,8 @@ api: flow: - llmGateway: google: {} - maxInputTokens: 100 - maxOutputTokens: 200 + policies: + maxInputTokens: 100 + maxOutputTokens: 200 target: url: https://generativelanguage.googleapis.com \ No newline at end of file diff --git a/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml index 2b6e344edd..4a9ef00ba4 100644 --- a/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/google/20-Sharing-API-Keys.yaml @@ -32,10 +32,10 @@ api: - llmGateway: google: {} apiKey: <> - # Limits per request - maxInputTokens: 100 - maxOutputTokens: 200 policies: + # Limits per request + maxInputTokens: 100 + maxOutputTokens: 200 models: - gemini-2.5-pro - gemini-2.5-flash diff --git a/distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml index 07ce7c4aff..0074494b40 100644 --- a/distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml +++ b/distribution/tutorials/ai/llm-gateway/openai/10-Basic-LLM-Gateway.yaml @@ -20,7 +20,8 @@ api: flow: - llmGateway: openai: {} - maxInputTokens: 100 - maxOutputTokens: 200 + policies: + maxInputTokens: 100 + maxOutputTokens: 200 target: url: https://api.openai.com diff --git a/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml index 19f8295c69..8aa3e72f4d 100644 --- a/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml +++ b/distribution/tutorials/ai/llm-gateway/openai/20-Sharing-API-Keys.yaml @@ -31,10 +31,10 @@ api: flow: - llmGateway: apiKey: <> - # Limits per request - maxInputTokens: 100 - maxOutputTokens: 200 policies: + # Limits per request + maxInputTokens: 100 + maxOutputTokens: 200 models: - gpt-5.4 - gpt-5-nano From 2449491d5c69ef0a5bb131254105a91eedb84307 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 27 May 2026 16:04:38 +0200 Subject: [PATCH 37/43] feat: refactor policies and introduce system prompt support - Replaced `Policies` class implementation with `DefaultPolicies` and `NullPolicies` for enhanced flexibility. - Added `SystemPrompt` class to support dynamic system prompt management in LLM Gateway. - Updated `LLMGatewayInterceptor` to delegate policy enforcement and system prompt handling to respective components. - Extended providers (OpenAI, Claude, Google Gemini) with standardized system prompt methods (`getSystemPrompt`, `setSystemPrompt`, `removeSystemPrompt`). - Enhanced test coverage with `AbstractLLMRequestTest` for API key handling and bearer token case insensitivity. --- .../llmgateway/DefaultPolicies.java | 126 ++++++++++++++++++ .../llmgateway/LLMGatewayInterceptor.java | 46 +++---- .../interceptor/llmgateway/NullPolicies.java | 31 +++++ .../core/interceptor/llmgateway/Policies.java | 77 ++--------- .../interceptor/llmgateway/SystemPrompt.java | 70 ++++++++++ .../provider/AbstractLLMRequest.java | 5 +- .../llmgateway/provider/LLMRequest.java | 8 ++ .../ChatCompletionsRequest.java | 60 +++++++++ .../provider/claude/ClaudeLLMRequest.java | 38 ++++++ .../provider/google/GoogleLLMRequest.java | 45 +++++++ .../openai/OpenAiLLMResponsesRequest.java | 31 +++++ .../provider/AbstractLLMRequestTest.java | 55 ++++++++ 12 files changed, 495 insertions(+), 97 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java create mode 100644 core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java new file mode 100644 index 0000000000..184b591136 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java @@ -0,0 +1,126 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway; + +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; +import static com.predic8.membrane.core.interceptor.Outcome.RETURN; + +/** + * @description LLM Gateway policies for token usage and model restrictions. + */ +@MCElement(name = "policies", id="llm-gateway-policies") +public class DefaultPolicies implements Policies { + + private static final Logger log = LoggerFactory.getLogger(LLMGatewayInterceptor.class); + + private LLMErrorCreator errorCreator; + + private List models; + private int maxOutputTokens; + private int maxInputTokens; + + public void init(LLMErrorCreator errorCreator) { + this.errorCreator = errorCreator; + } + + public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { + + var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); + var inputTokens = aiReq.estimateInputTokens(); + + if (maxOutputTokens > 0) { + if (requestedMaxOutputTokens <= 0) { + log.info("No max. output requested. Setting limit to {}.", maxOutputTokens); + aiReq.setMaxOutputTokens(maxOutputTokens); + } else if (requestedMaxOutputTokens > maxOutputTokens) { + log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, maxOutputTokens); + aiReq.setMaxOutputTokens(maxOutputTokens); + } + } + + if (maxInputTokens != 0) { + if (inputTokens > maxInputTokens) { + log.info("Input tokens {} exceed the limit of {}.", inputTokens, maxInputTokens); + exc.setResponse(errorCreator.inputTokensExceeded(maxInputTokens, inputTokens)); + return RETURN; + } + } + + if (models != null) { + var model = aiReq.getModel(); + if (!models.contains(model)) { + exc.setResponse(errorCreator.modelNotAllowed(model, models)); + return RETURN; + } + } + + return CONTINUE; + } + + public List getModels() { + return models; + } + + /** + * @param models List of models that can be used by the gateway. + * @desciption Restricts the models that can be used by the gateway. + * @default null (no restriction) + */ + @MCAttribute + public void setModels(List models) { + this.models = models; + } + + + public int getMaxOutputTokens() { + return maxOutputTokens; + } + + /** + * @param maxOutputTokens Maximum number of tokens the LLM should use to generate a response. + * @description Maximum number of tokens the LLM should use to generate a response. This is just a hint that the gateway + * sends to the LLM provider. The provider may use a different limit. + * @default 0 (unlimited) + */ + @MCAttribute + public void setMaxOutputTokens(int maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + } + + public int getMaxInputTokens() { + return maxInputTokens; + } + + /** + * @param maxInputTokens Maximum number of tokens that a request can use. + * @description Restricts token usage for the input. The size of the input is estimated by gateway based on the request size. + * Actual token usage may be deviate from this value. + */ + @MCAttribute + public void setMaxInputTokens(int maxInputTokens) { + this.maxInputTokens = maxInputTokens; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java index 2824842f67..8bca1b36c4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java @@ -58,7 +58,9 @@ public class LLMGatewayInterceptor extends AbstractInterceptor { private String apiKey; - private Policies policies = new Policies(); + private Policies policies = new NullPolicies(); + + private SystemPrompt systemPrompt; private AiApiStore store; @@ -66,6 +68,7 @@ public class LLMGatewayInterceptor extends AbstractInterceptor { public void init() { super.init(); errorCreator = provider.getErrorCreator(); + policies.init(errorCreator); if (store != null) store.init(router); @@ -126,36 +129,18 @@ public Outcome handleRequest(Exchange exc) { log.debug("Requested model: {}", aiReq.getModel()); - var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); - - if (policies.getMaxOutputTokens() > 0) { - if (requestedMaxOutputTokens <= 0) { - log.info("No max. output requested. Setting limit to {}.", policies.getMaxOutputTokens()); - aiReq.setMaxOutputTokens(policies.getMaxOutputTokens()); - } else if (requestedMaxOutputTokens > policies.getMaxOutputTokens()) { - log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, policies.getMaxOutputTokens()); - aiReq.setMaxOutputTokens(policies.getMaxOutputTokens()); - } + var outcome = policies.handleRequest(aiReq,exc); + if (outcome != CONTINUE) { + return outcome; } - if (policies.getMaxInputTokens() != 0) { - if (inputTokens > policies.getMaxInputTokens()) { - log.info("Input tokens {} exceed the limit of {}.", inputTokens, policies.getMaxInputTokens()); - exc.setResponse(errorCreator.inputTokensExceeded(policies.getMaxInputTokens(), inputTokens)); - return RETURN; + if (systemPrompt != null) { + outcome = systemPrompt.handleRequest(aiReq,exc); + if (outcome != CONTINUE) { + return outcome; } } - if (policies.getModels() != null) { - var model = aiReq.getModel(); - if (!policies.getModels().contains(model)) { - exc.setResponse(errorCreator.modelNotAllowed(model, policies.getModels())); - return RETURN; - } - } - - log.debug("Agent provides the tools: {}", aiReq.getTools()); - setJsonBody(exc.getRequest(), aiReq.getJson()); return CONTINUE; } @@ -238,4 +223,13 @@ public Policies getPolicies() { public void setPolicies(Policies policies) { this.policies = policies; } + + public SystemPrompt getSystemPrompt() { + return systemPrompt; + } + + @MCChildElement + public void setSystemPrompt(SystemPrompt systemPrompt) { + this.systemPrompt = systemPrompt; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java new file mode 100644 index 0000000000..8acd6df555 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java @@ -0,0 +1,31 @@ +package com.predic8.membrane.core.interceptor.llmgateway; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; + +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; + +public class NullPolicies implements Policies { + + @Override + public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { + return CONTINUE; + } + + @Override + public void init(LLMErrorCreator errorCreator) { + + } + + @Override + public int getMaxOutputTokens() { + return 0; + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + } +} + diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java index cfbf960e10..62419f0ed2 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java @@ -1,76 +1,17 @@ -/* Copyright 2026 predic8 GmbH, www.predic8.com - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - package com.predic8.membrane.core.interceptor.llmgateway; -import com.predic8.membrane.annot.MCAttribute; -import com.predic8.membrane.annot.MCElement; - -import java.util.List; - -/** - * LLM Gateway policies for token usage and model restrictions. - */ -@MCElement(name = "policies", topLevel = false, id="llm-gateway-policies") -public class Policies { - - private List models; - private int maxOutputTokens; - private int maxInputTokens; - - public List getModels() { - return models; - } - - /** - * @param models List of models that can be used by the gateway. - * @desciption Restricts the models that can be used by the gateway. - * @default null (no restriction) - */ - @MCAttribute - public void setModels(List models) { - this.models = models; - } - +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; - public int getMaxOutputTokens() { - return maxOutputTokens; - } +public interface Policies { - /** - * @param maxOutputTokens Maximum number of tokens the LLM should use to generate a response. - * @description Maximum number of tokens the LLM should use to generate a response. This is just a hint that the gateway - * sends to the LLM provider. The provider may use a different limit. - * @default 0 (unlimited) - */ - @MCAttribute - public void setMaxOutputTokens(int maxOutputTokens) { - this.maxOutputTokens = maxOutputTokens; - } + Outcome handleRequest(LLMRequest aiReq, Exchange exc); - public int getMaxInputTokens() { - return maxInputTokens; - } + void init(LLMErrorCreator errorCreator); - /** - * @param maxInputTokens Maximum number of tokens that a request can use. - * @description Restricts token usage for the input. The size of the input is estimated by gateway based on the request size. - * Actual token usage may be deviate from this value. - */ - @MCAttribute - public void setMaxInputTokens(int maxInputTokens) { - this.maxInputTokens = maxInputTokens; - } + int getMaxOutputTokens(); + void setMaxOutputTokens(int maxOutputTokens); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java new file mode 100644 index 0000000000..c301d46c34 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java @@ -0,0 +1,70 @@ +package com.predic8.membrane.core.interceptor.llmgateway; + +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; + +/** + * @description When used with older chat completions API the instruction is converted to a system message like: + * "system": "You are a helpful assistant." + */ +@MCElement(name = "systemPrompt") +public class SystemPrompt { + + private static final Logger log = LoggerFactory.getLogger(SystemPrompt.class); + public static final String INSTRUCTIONS = "instructions"; + + enum Action { + REJECT, REMOVE, OVERWRITE, APPEND, PREPEND + } + + private Action action; + private String content = ""; + + public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { + var instructions = aiReq.getSystemPrompt() == null ? "" : aiReq.getSystemPrompt(); + switch (action) { + case OVERWRITE -> { + log.debug("Overwriting instructions: {}", content); + aiReq.setSystemPrompt(content); + } + case PREPEND -> { + log.debug("Prepending instructions: {}", content); + aiReq.setSystemPrompt( content + "\n" + instructions); + } + case APPEND -> { + log.debug("Appending instructions: {}", content); + aiReq.setSystemPrompt(instructions + "\n" + content); + } + case REMOVE -> { + log.info("Removing instructions: {}", instructions); + aiReq.removeSystemPrompt(); + } + } + return CONTINUE; + } + + public Action getAction() { + return action; + } + + @MCAttribute + public void setAction(Action action) { + this.action = action; + } + + public String getContent() { + return content; + } + + @MCAttribute + public void setContent(String content) { + this.content = content; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java index f5955d6acb..54df557594 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java @@ -70,12 +70,11 @@ public String getApiKey() { return null; } - int index = ah.indexOf(BEARER_PREFIX); - if (index < 0) { + if (!ah.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length())) { return null; } - var token = ah.substring(index + BEARER_PREFIX.length()).trim(); + var token = ah.substring(BEARER_PREFIX.length()).trim(); return token.isEmpty() ? null : token; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java index 371115e911..64dee19dad 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java @@ -40,4 +40,12 @@ public interface LLMRequest { List getTools(); + String getSystemPrompt(); + + boolean isChatCompletion(); + + void setSystemPrompt(String systemPrompt); + + void removeSystemPrompt(); } + diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java index 4ecbf9065a..951fb99edc 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java @@ -14,6 +14,7 @@ package com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.AbstractOpenAiLLMRequest; @@ -53,6 +54,65 @@ public List getTools() { .toList(); } + /** + * Returns the content of the first {@code "role": "system"} message, + * or an empty string if none is present. + */ + @Override + public String getSystemPrompt() { + for (var message : json.path("messages")) { + if ("system".equals(message.path("role").asText())) { + return message.path("content").asText(""); + } + } + return ""; + } + + /** + * Sets the system prompt in the {@code "messages"} array. + * If a system message already exists its {@code "content"} is updated in place; + * otherwise a new {@code {"role":"system","content":"..."}} entry is prepended. + * + *

Chat Completions API wire format: + *

{@code
+     * { "messages": [{"role": "system", "content": "You are a helpful assistant."}, ...] }
+     * }
+ */ + @Override + public void setSystemPrompt(String systemPrompt) { + var messages = json.withArray("messages"); + for (var message : messages) { + if ("system".equals(message.path("role").asText())) { + ((ObjectNode) message).put("content", systemPrompt); + return; + } + } + // No system message found — prepend one + var systemMessage = json.objectNode(); + systemMessage.put("role", "system"); + systemMessage.put("content", systemPrompt); + messages.insert(0, systemMessage); + } + + /** + * Removes all {@code "role": "system"} messages from the {@code "messages"} array. + * Has no effect if no system message is present. + */ + @Override + public void removeSystemPrompt() { + var messages = json.withArray("messages"); + for (int i = messages.size() - 1; i >= 0; i--) { + if ("system".equals(messages.get(i).path("role").asText())) { + messages.remove(i); + } + } + } + + @Override + public boolean isChatCompletion() { + return true; + } + @Override public long getRequestedMaxOutputTokens() { return json.path("max_completion_tokens").asLong(0); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java index fa5279afe4..e99d06a979 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java @@ -20,6 +20,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * system field for system prompt + */ public class ClaudeLLMRequest extends AbstractLLMRequest { private static final Logger log = LoggerFactory.getLogger(ClaudeLLMRequest.class); @@ -80,6 +83,20 @@ public long estimateInputTokens() { return tokens; } + /** + * Returns the system prompt from the top-level {@code "system"} field, + * or an empty string if no system prompt is set. + */ + @Override + public String getSystemPrompt() { + return json.path("system").asText(""); + } + + @Override + public boolean isChatCompletion() { + return false; + } + private boolean isThinking() { var thinking = json.path("thinking"); return thinking.isObject() && "enabled".equals(thinking.path("type").asText()); @@ -105,4 +122,25 @@ public void setApiKey(String apiKey) { exchange.getRequest().getHeader().removeFields(X_API_KEY); exchange.getRequest().getHeader().add(X_API_KEY, apiKey); } + + /** + * Sets the top-level {@code "system"} field to {@code systemPrompt}. + * Replaces any existing system prompt. + * + *

Claude API wire format: + *

{@code { "system": "You are a helpful assistant.", "messages": [...] }}
+ */ + @Override + public void setSystemPrompt(String systemPrompt) { + json.put("system", systemPrompt); + } + + /** + * Removes the top-level {@code "system"} field entirely. + * Has no effect if no system prompt is present. + */ + @Override + public void removeSystemPrompt() { + json.remove("system"); + } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java index bd60b10617..ad7d4328a3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java @@ -97,6 +97,51 @@ public long estimateInputTokens() { return Math.max(1, Math.round(chars / 4.0 * 1.15)); } + /** + * Returns the text of the first part inside {@code systemInstruction}, + * or an empty string if no system prompt is set. + * + *

Gemini API wire format: + *

{@code
+     * { "systemInstruction": { "parts": [{ "text": "You are a helpful assistant." }] } }
+     * }
+ */ + @Override + public String getSystemPrompt() { + for (var part : json.path("systemInstruction").path("parts")) { + if (part.path("text").isTextual()) { + return part.path("text").asText(""); + } + } + return ""; + } + + /** + * Sets {@code systemInstruction} to a single text part carrying {@code systemPrompt}. + * Replaces any existing system instruction. + */ + @Override + public void setSystemPrompt(String systemPrompt) { + json.putObject("systemInstruction") + .putArray("parts") + .addObject() + .put("text", systemPrompt); + } + + /** + * Removes the {@code systemInstruction} field entirely. + * Has no effect if no system instruction is present. + */ + @Override + public void removeSystemPrompt() { + json.remove("systemInstruction"); + } + + @Override + public boolean isChatCompletion() { + return exchange.getRequest().getUri().contains("/chat/completions"); + } + private long countText(JsonNode node) { if (node == null || node.isMissingNode() || node.isNull()) { return 0; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java index 3caa187c88..ef7867f1d0 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java @@ -37,6 +37,37 @@ public List getTools() { .toList(); } + @Override + public String getSystemPrompt() { + return json.path("instructions").asText(); + } + + @Override + public boolean isChatCompletion() { + return false; + } + + /** + * Sets the {@code "instructions"} field, which is the system prompt in the + * OpenAI Responses API. Replaces any existing value. + * + *

OpenAI Responses API wire format: + *

{@code { "instructions": "You are a helpful assistant.", "input": "..." }}
+ */ + @Override + public void setSystemPrompt(String systemPrompt) { + json.put("instructions", systemPrompt); + } + + /** + * Removes the {@code "instructions"} field entirely. + * Has no effect if no system prompt is present. + */ + @Override + public void removeSystemPrompt() { + json.remove("instructions"); + } + @Override public long getRequestedMaxOutputTokens() { if (json.has("max_output_tokens")) diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java new file mode 100644 index 0000000000..747b24e540 --- /dev/null +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java @@ -0,0 +1,55 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.net.URISyntaxException; + +import static com.predic8.membrane.core.http.Request.post; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class AbstractLLMRequestTest { + + @ParameterizedTest + @ValueSource(strings = { + "Bearer test-api-key", + "bearer test-api-key", + "BEARER test-api-key", + "bEaReR test-api-key" + }) + void getApiKeyAcceptsBearerCaseInsensitive(String authorization) throws URISyntaxException { + var request = new TestLLMRequest(post("http://localhost/chat/completions") + .header("Authorization", authorization) + .json("{}") + .buildExchange()); + + assertEquals("test-api-key", request.getApiKey()); + } + + private static class TestLLMRequest extends AbstractLLMRequest { + + TestLLMRequest(Exchange exchange) { + super(exchange); + } + + @Override + public long getRequestedMaxOutputTokens() { + return -1; + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + } + + @Override + public long estimateInputTokens() { + return 0; + } + + @Override + public String getSystemPrompt() { + return null; + } + } +} \ No newline at end of file From 60a66e6446ac6330c7ab4cca51e38c5bd2522f50 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 27 May 2026 16:06:35 +0200 Subject: [PATCH 38/43] feat: extend `SystemPrompt` with new actions and update tests - Added `setSystemPrompt`, `removeSystemPrompt`, and `isChatCompletion` methods for enhanced prompt management. - Refactored `SystemPrompt.Action` to remove unused `REJECT` action. - Updated `AbstractLLMRequestTest` to validate new `SystemPrompt` behaviors. --- .../core/interceptor/llmgateway/SystemPrompt.java | 5 ++--- .../provider/AbstractLLMRequestTest.java | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java index c301d46c34..a63c1535c0 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java @@ -18,10 +18,9 @@ public class SystemPrompt { private static final Logger log = LoggerFactory.getLogger(SystemPrompt.class); - public static final String INSTRUCTIONS = "instructions"; - enum Action { - REJECT, REMOVE, OVERWRITE, APPEND, PREPEND + public enum Action { + REMOVE, OVERWRITE, APPEND, PREPEND } private Action action; diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java index 747b24e540..af841ffed8 100644 --- a/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java @@ -51,5 +51,20 @@ public long estimateInputTokens() { public String getSystemPrompt() { return null; } + + @Override + public boolean isChatCompletion() { + return false; + } + + @Override + public void setSystemPrompt(String systemPrompt) { + + } + + @Override + public void removeSystemPrompt() { + + } } } \ No newline at end of file From 9c8f24cced6904a47bdd1aceb6ed3329a2f45ce9 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 May 2026 08:53:58 +0200 Subject: [PATCH 39/43] feat: improve policy validation and consolidate system prompt handling across providers - Added validation for token limits in `DefaultPolicies` and `AiApiUser` classes. - Refactored system prompt methods (`setSystemPrompts`, `getRequestedMaxOutputTokens`) for consistency across LLM providers. - Standardized concatenation logic for multi-prompt handling in providers (Claude, OpenAI, Google, Chat Completions). - Enhanced error handling with `ConfigurationException` in token-related attributes. --- .../llmgateway/DefaultPolicies.java | 13 +++++-- .../interceptor/llmgateway/SystemPrompt.java | 10 +++--- .../llmgateway/provider/LLMRequest.java | 2 +- .../ChatCompletionsProvider.java | 2 +- .../ChatCompletionsRequest.java | 36 ++++++++++--------- .../provider/claude/ClaudeLLMRequest.java | 11 +++--- .../provider/google/GoogleLLMRequest.java | 16 ++++++--- .../openai/AbstractOpenAiLLMRequest.java | 3 +- .../openai/OpenAiLLMResponsesRequest.java | 11 +++--- .../llmgateway/store/AiApiUser.java | 10 ++++-- 10 files changed, 69 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java index 184b591136..caa8ae52aa 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java @@ -20,6 +20,7 @@ import com.predic8.membrane.core.interceptor.Outcome; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.util.ConfigurationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,10 +32,10 @@ /** * @description LLM Gateway policies for token usage and model restrictions. */ -@MCElement(name = "policies", id="llm-gateway-policies") +@MCElement(name = "policies", id = "llm-gateway-policies") public class DefaultPolicies implements Policies { - private static final Logger log = LoggerFactory.getLogger(LLMGatewayInterceptor.class); + private static final Logger log = LoggerFactory.getLogger(DefaultPolicies.class); private LLMErrorCreator errorCreator; @@ -45,7 +46,7 @@ public class DefaultPolicies implements Policies { public void init(LLMErrorCreator errorCreator) { this.errorCreator = errorCreator; } - + public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); @@ -107,6 +108,9 @@ public int getMaxOutputTokens() { */ @MCAttribute public void setMaxOutputTokens(int maxOutputTokens) { + if (maxOutputTokens < 0) { + throw new IllegalArgumentException("maxOutputTokens must be >= 0"); + } this.maxOutputTokens = maxOutputTokens; } @@ -121,6 +125,9 @@ public int getMaxInputTokens() { */ @MCAttribute public void setMaxInputTokens(int maxInputTokens) { + if (maxInputTokens < 0) { + throw new ConfigurationException("maxInputTokens must be >= 0"); + } this.maxInputTokens = maxInputTokens; } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java index a63c1535c0..e969cec75f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java @@ -8,6 +8,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; + import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; /** @@ -23,7 +25,7 @@ public enum Action { REMOVE, OVERWRITE, APPEND, PREPEND } - private Action action; + private Action action = Action.OVERWRITE; private String content = ""; public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { @@ -31,15 +33,15 @@ public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { switch (action) { case OVERWRITE -> { log.debug("Overwriting instructions: {}", content); - aiReq.setSystemPrompt(content); + aiReq.setSystemPrompts(List.of(content)); } case PREPEND -> { log.debug("Prepending instructions: {}", content); - aiReq.setSystemPrompt( content + "\n" + instructions); + aiReq.setSystemPrompts(List.of(content, instructions)); } case APPEND -> { log.debug("Appending instructions: {}", content); - aiReq.setSystemPrompt(instructions + "\n" + content); + aiReq.setSystemPrompts(List.of(instructions, content)); } case REMOVE -> { log.info("Removing instructions: {}", instructions); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java index 64dee19dad..31a65919c9 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java @@ -44,7 +44,7 @@ public interface LLMRequest { boolean isChatCompletion(); - void setSystemPrompt(String systemPrompt); + void setSystemPrompts(List prompts); void removeSystemPrompt(); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java index 1ac5be3699..e2089bbcc0 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java @@ -57,6 +57,6 @@ public LLMResponse getLLMResponse(Exchange request, Consumer postPr @Override public LLMErrorCreator getErrorCreator() { - return null; + return new ChatCompletionsErrorCreator(); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java index 951fb99edc..e9f11b2db9 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java @@ -14,7 +14,6 @@ package com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions; -import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.AbstractOpenAiLLMRequest; @@ -69,29 +68,29 @@ public String getSystemPrompt() { } /** - * Sets the system prompt in the {@code "messages"} array. - * If a system message already exists its {@code "content"} is updated in place; - * otherwise a new {@code {"role":"system","content":"..."}} entry is prepended. + * Replaces all system messages with one separate {@code {"role":"system","content":"..."}} message + * per prompt, prepended to the messages array in list order. * *

Chat Completions API wire format: *

{@code
-     * { "messages": [{"role": "system", "content": "You are a helpful assistant."}, ...] }
+     * { "messages": [
+     *     {"role": "system", "content": "prompt 1"},
+     *     {"role": "system", "content": "prompt 2"},
+     *     ...user messages...
+     * ]}
      * }
*/ @Override - public void setSystemPrompt(String systemPrompt) { + public void setSystemPrompts(List prompts) { + removeSystemPrompt(); var messages = json.withArray("messages"); - for (var message : messages) { - if ("system".equals(message.path("role").asText())) { - ((ObjectNode) message).put("content", systemPrompt); - return; - } + // Insert in reverse so that prompts[0] ends up at index 0 + for (int i = prompts.size() - 1; i >= 0; i--) { + var systemMessage = json.objectNode(); + systemMessage.put("role", "system"); + systemMessage.put("content", prompts.get(i)); + messages.insert(0, systemMessage); } - // No system message found — prepend one - var systemMessage = json.objectNode(); - systemMessage.put("role", "system"); - systemMessage.put("content", systemPrompt); - messages.insert(0, systemMessage); } /** @@ -115,7 +114,10 @@ public boolean isChatCompletion() { @Override public long getRequestedMaxOutputTokens() { - return json.path("max_completion_tokens").asLong(0); + // Prefer max_completion_tokens (modern OpenAI/o1+), fall back to max_tokens (legacy / all other providers) + long v = json.path("max_completion_tokens").asLong(0); + if (v > 0) return v; + return json.path("max_tokens").asLong(0); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java index e99d06a979..40eb7261ec 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java @@ -20,6 +20,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; + /** * system field for system prompt */ @@ -124,15 +126,14 @@ public void setApiKey(String apiKey) { } /** - * Sets the top-level {@code "system"} field to {@code systemPrompt}. - * Replaces any existing system prompt. + * Concatenates all prompts (newline-separated) into the top-level {@code "system"} field. * *

Claude API wire format: - *

{@code { "system": "You are a helpful assistant.", "messages": [...] }}
+ *
{@code { "system": "prompt 1\nprompt 2", "messages": [...] }}
*/ @Override - public void setSystemPrompt(String systemPrompt) { - json.put("system", systemPrompt); + public void setSystemPrompts(List prompts) { + json.put("system", String.join("\n", prompts)); } /** diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java index ad7d4328a3..adb62a63f4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java @@ -19,6 +19,8 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; +import java.util.List; + public class GoogleLLMRequest extends AbstractLLMRequest { /** @@ -117,15 +119,18 @@ public String getSystemPrompt() { } /** - * Sets {@code systemInstruction} to a single text part carrying {@code systemPrompt}. - * Replaces any existing system instruction. + * Concatenates all prompts (newline-separated) into a single text part under + * {@code systemInstruction}. Replaces any existing system instruction. + * + *

Gemini API wire format: + *

{@code { "systemInstruction": { "parts": [{ "text": "prompt 1\nprompt 2" }] } }}
*/ @Override - public void setSystemPrompt(String systemPrompt) { + public void setSystemPrompts(List prompts) { json.putObject("systemInstruction") .putArray("parts") .addObject() - .put("text", systemPrompt); + .put("text", String.join("\n", prompts)); } /** @@ -139,7 +144,8 @@ public void removeSystemPrompt() { @Override public boolean isChatCompletion() { - return exchange.getRequest().getUri().contains("/chat/completions"); + // Gemini uses its own generateContent API, not Chat Completions + return false; } private long countText(JsonNode node) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java index b49e7440fc..0686831bba 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java @@ -31,8 +31,9 @@ public long estimateInputTokens() { chars += estimateChatCompletitions(); - // system instructions + // system instructions: "system" (chat completions) or "instructions" (responses API) chars += countText(json.path("system")); + chars += countText(json.path("instructions")); // tools/functions contribute significantly chars += countJsonSize(json.path("tools")); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java index ef7867f1d0..7825772d92 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java @@ -39,7 +39,7 @@ public List getTools() { @Override public String getSystemPrompt() { - return json.path("instructions").asText(); + return json.path("instructions").asText(""); } @Override @@ -48,15 +48,14 @@ public boolean isChatCompletion() { } /** - * Sets the {@code "instructions"} field, which is the system prompt in the - * OpenAI Responses API. Replaces any existing value. + * Concatenates all prompts (newline-separated) into the {@code "instructions"} field. * *

OpenAI Responses API wire format: - *

{@code { "instructions": "You are a helpful assistant.", "input": "..." }}
+ *
{@code { "instructions": "prompt 1\nprompt 2", "input": "..." }}
*/ @Override - public void setSystemPrompt(String systemPrompt) { - json.put("instructions", systemPrompt); + public void setSystemPrompts(List prompts) { + json.put("instructions", String.join("\n", prompts)); } /** diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiUser.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiUser.java index da8b792680..d2a5c9b018 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiUser.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/AiApiUser.java @@ -16,12 +16,13 @@ import com.predic8.membrane.annot.MCAttribute; import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.util.ConfigurationException; import java.util.concurrent.atomic.AtomicLong; import static java.lang.Long.MAX_VALUE; -@MCElement(name = "users", component = false, id="ai-api-users") +@MCElement(name = "users", component = false, id = "ai-api-users") public class AiApiUser { private String name; @@ -33,6 +34,7 @@ public class AiApiUser { /** * Updates the store with the number of tokens used in this call + * * @param usage The number of tokens used */ public void addTokensUsedInPeriod(Usage usage) { @@ -45,6 +47,7 @@ public void resetTokensUsedInPeriod() { /** * Checks if the user has enough tokens to make the request. + * * @param tokensNeededForRequest The number of tokens that the user needs to make the request * @return The estimated number of tokens that the user has left after this request */ @@ -59,8 +62,8 @@ public String getName() { } /** - * @description Name of the API user, group or cost center. * @param name of the user + * @description Name of the API user, group or cost center. */ @MCAttribute() public void setName(String name) { @@ -93,6 +96,9 @@ public long getTokens() { */ @MCAttribute public void setTokens(long tokens) { + if (tokens < 0) { + throw new ConfigurationException("tokens must be >= 0"); + } this.tokens = tokens; } From 9f615abf7d198d52036f601df24be74fd16459fa Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 May 2026 15:05:31 +0200 Subject: [PATCH 40/43] feat: add AbstractLLMProvider, extend request handling, and improve multipart utility - Introduced `AbstractLLMProvider` to streamline and centralize LLM request type handling. - Added `AbstractLLMRequest` and extended support for specialized requests (e.g., `AudioRequest`, `ImagesRequest`, `FilesRequest`, `OrganizationRequest`). - Implemented `MultipartUtil` to simplify handling of multipart HTTP messages. - Updated providers (OpenAI, Claude, Chat Completions) to align with new abstractions and support for `IOException`. - Enhanced test coverage with `MultipartUtilTest` and additional provider-specific unit tests. --- .../predic8/membrane/core/http/Header.java | 63 +++-- .../predic8/membrane/core/http/Message.java | 7 + .../llmgateway/DefaultPolicies.java | 48 ++-- .../llmgateway/LLMGatewayInterceptor.java | 66 ++--- .../interceptor/llmgateway/NullPolicies.java | 4 +- .../core/interceptor/llmgateway/Policies.java | 4 +- .../interceptor/llmgateway/SystemPrompt.java | 14 +- .../provider/AbstractLLMProvider.java | 35 +++ .../provider/AbstractLLMRequest.java | 58 +---- .../llmgateway/provider/JSONRequest.java | 8 + .../llmgateway/provider/LLMProvider.java | 3 +- .../llmgateway/provider/LLMRequest.java | 27 -- .../provider/ModelInputRequest.java | 30 +++ .../ChatCompletionsProvider.java | 3 +- .../ChatCompletionsRequest.java | 8 +- .../provider/claude/ClaudeLLMRequest.java | 13 +- .../provider/claude/ClaudeProvider.java | 3 +- .../provider/google/GoogleLLMRequest.java | 14 +- .../provider/google/GoogleProvider.java | 3 +- .../openai/AbstractOpenAiLLMRequest.java | 8 +- .../provider/openai/AudioRequest.java | 12 + .../provider/openai/FilesRequest.java | 12 + .../provider/openai/ImagesRequest.java | 11 + .../openai/OpenAIChatCompletionsRequest.java | 4 +- .../provider/openai/OpenAIProvider.java | 18 +- .../openai/OpenAiLLMResponsesRequest.java | 8 +- .../provider/openai/OrganizationRequest.java | 11 + .../llmgateway/store/SimpleAiApiStore.java | 2 +- .../core/interceptor/log/LogInterceptor.java | 4 + .../core/multipart/MultipartUtil.java | 98 +++++++ .../predic8/membrane/core/multipart/Part.java | 243 ++++++++---------- .../core/multipart/XOPReconstitutor.java | 102 +++++--- .../membrane/core/http/HeaderTest.java | 61 ++++- ...ava => AbstractModelInputRequestTest.java} | 17 +- .../core/multipart/MultipartUtilTest.java | 229 +++++++++++++++++ .../BasicClaudeLLMGatewayTutorialTest.java | 2 + .../SharingApiKeysOpenAiTutorialTest.java | 1 + 37 files changed, 858 insertions(+), 396 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/JSONRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/ModelInputRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OrganizationRequest.java create mode 100644 core/src/main/java/com/predic8/membrane/core/multipart/MultipartUtil.java rename core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/{AbstractLLMRequestTest.java => AbstractModelInputRequestTest.java} (79%) create mode 100644 core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java diff --git a/core/src/main/java/com/predic8/membrane/core/http/Header.java b/core/src/main/java/com/predic8/membrane/core/http/Header.java index 68844f5a18..cc435d43c4 100644 --- a/core/src/main/java/com/predic8/membrane/core/http/Header.java +++ b/core/src/main/java/com/predic8/membrane/core/http/Header.java @@ -15,29 +15,41 @@ package com.predic8.membrane.core.http; import com.predic8.membrane.annot.Constants; -import com.predic8.membrane.core.http.cookie.*; -import com.predic8.membrane.core.util.*; -import jakarta.mail.internet.*; -import org.jetbrains.annotations.*; -import org.slf4j.*; - -import java.io.*; -import java.security.*; -import java.util.*; +import com.predic8.membrane.core.http.cookie.Cookies; +import com.predic8.membrane.core.http.cookie.MimeHeaders; +import com.predic8.membrane.core.http.cookie.ServerCookie; +import com.predic8.membrane.core.util.EndOfStreamException; +import com.predic8.membrane.core.util.HttpUtil; +import jakarta.mail.internet.ContentType; +import jakarta.mail.internet.ParseException; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.InvalidParameterException; import java.util.ArrayList; -import java.util.function.*; -import java.util.regex.*; -import java.util.stream.*; - -import static com.predic8.membrane.core.http.MimeType.*; -import static com.predic8.membrane.core.util.HttpUtil.*; -import static java.nio.charset.StandardCharsets.*; -import static java.util.Arrays.*; -import static java.util.Collections.*; +import java.util.List; +import java.util.Set; +import java.util.function.Predicate; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Stream; + +import static com.predic8.membrane.core.http.MimeType.isBinary; +import static com.predic8.membrane.core.util.HttpUtil.readLine; +import static java.nio.charset.StandardCharsets.ISO_8859_1; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.stream; +import static java.util.Collections.unmodifiableList; import static java.util.Locale.ROOT; -import static java.util.regex.Pattern.*; -import static java.util.stream.Collectors.*; -import static org.apache.commons.codec.binary.Base64.*; +import static java.util.regex.Pattern.CASE_INSENSITIVE; +import static java.util.regex.Pattern.compile; +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toSet; +import static org.apache.commons.codec.binary.Base64.encodeBase64; /** * The headers of an HTTP message. @@ -331,6 +343,15 @@ public String getContentType() { return getFirstValue(CONTENT_TYPE); } + /** + * Returns {@code true} if the {@code Content-Type} header starts with {@code multipart/} + * (e.g. {@code multipart/form-data}, {@code multipart/related}, {@code multipart/mixed}). + */ + public boolean isMultipart() { + String ct = getContentType(); + return ct != null && ct.regionMatches(true, 0, "multipart/", 0, 10); + } + public String getUserAgent() { return getFirstValue(USER_AGENT); } diff --git a/core/src/main/java/com/predic8/membrane/core/http/Message.java b/core/src/main/java/com/predic8/membrane/core/http/Message.java index f38024aae0..b96755086f 100644 --- a/core/src/main/java/com/predic8/membrane/core/http/Message.java +++ b/core/src/main/java/com/predic8/membrane/core/http/Message.java @@ -338,6 +338,13 @@ public boolean isImage() { return MimeType.isImage(getHeader().getContentType()); } + /** + * @return true if the message has a media type of image/*, audio/*, video/*, octect-stream, or application/octet-stream + */ + public boolean isBinary() { + return MimeType.isBinary(getHeader().getContentType()); + } + public boolean isXML() { return MimeType.isXML(getHeader().getContentType()); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java index caa8ae52aa..7c6f9612c4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java @@ -19,7 +19,8 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.Outcome; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.ModelInputRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.OrganizationRequest; import com.predic8.membrane.core.util.ConfigurationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,18 +48,44 @@ public void init(LLMErrorCreator errorCreator) { this.errorCreator = errorCreator; } - public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { + public Outcome handleRequest(ModelInputRequest mir, Exchange exc) { - var requestedMaxOutputTokens = aiReq.getRequestedMaxOutputTokens(); - var inputTokens = aiReq.estimateInputTokens(); + if (mir instanceof OrganizationRequest) { + return CONTINUE; + } + + var outcome = checkTokenLimits(mir, exc); + if (outcome != CONTINUE) { + return outcome; + } + outcome = checkModel(mir, exc); + if (outcome != CONTINUE) { + return outcome; + } + return CONTINUE; + } + + public Outcome checkModel(ModelInputRequest mir, Exchange exc) { + var model = mir.getModel(); + if (models != null && !models.contains(model)) { + exc.setResponse(errorCreator.modelNotAllowed(model, models)); + return RETURN; + } + return CONTINUE; + } + + public Outcome checkTokenLimits(ModelInputRequest mir, Exchange exc) { + + var requestedMaxOutputTokens = mir.getRequestedMaxOutputTokens(); + var inputTokens = mir.estimateInputTokens(); if (maxOutputTokens > 0) { if (requestedMaxOutputTokens <= 0) { log.info("No max. output requested. Setting limit to {}.", maxOutputTokens); - aiReq.setMaxOutputTokens(maxOutputTokens); + mir.setMaxOutputTokens(maxOutputTokens); } else if (requestedMaxOutputTokens > maxOutputTokens) { log.info("Requested max. output tokens {} exceed the limit. Setting limit to {}.", requestedMaxOutputTokens, maxOutputTokens); - aiReq.setMaxOutputTokens(maxOutputTokens); + mir.setMaxOutputTokens(maxOutputTokens); } } @@ -69,15 +96,6 @@ public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { return RETURN; } } - - if (models != null) { - var model = aiReq.getModel(); - if (!models.contains(model)) { - exc.setResponse(errorCreator.modelNotAllowed(model, models)); - return RETURN; - } - } - return CONTINUE; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java index 8bca1b36c4..fca36d7066 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/LLMGatewayInterceptor.java @@ -23,6 +23,7 @@ import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMProvider; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.ModelInputRequest; import com.predic8.membrane.core.interceptor.llmgateway.store.AiApiStore; import com.predic8.membrane.core.interceptor.llmgateway.store.AiApiUser; import com.predic8.membrane.core.util.ConfigurationException; @@ -31,7 +32,6 @@ import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; import static com.predic8.membrane.core.interceptor.Outcome.RETURN; -import static com.predic8.membrane.core.util.json.JsonUtil.setJsonBody; /* * @description

@@ -81,23 +81,17 @@ public void init() { @Override public Outcome handleRequest(Exchange exc) { - LLMRequest aiReq; + LLMRequest llmReq; try { - aiReq = provider.getLLMRequest(exc); + llmReq = provider.getLLMRequest(exc); } catch (Exception e) { exc.setResponse(errorCreator.invalidRequestError("Error parsing request: " + e.getMessage())); return RETURN; } - if (!exc.getRequest().isPOSTRequest()) { - if (apiKey != null) - aiReq.setApiKey(apiKey); - return CONTINUE; - } - AiApiUser user = null; if (store != null) { - var opt = store.getUser(aiReq.getApiKey()); + var opt = store.getUser(llmReq.getApiKey()); if (opt.isEmpty()) { exc.setResponse(errorCreator.authenticationFailed()); return RETURN; @@ -107,41 +101,53 @@ public Outcome handleRequest(Exchange exc) { exc.setProperty(MEMBRANE_AI_USER, user); } - long inputTokens = aiReq.estimateInputTokens(); - log.debug("Estimated input tokens: {}", inputTokens); - - // Check store limits - if (store != null) { - var effectiveMaxTokens = computeEffectiveMaxOutputTokens(aiReq.getRequestedMaxOutputTokens(), policies.getMaxOutputTokens()); - var remaining = store.checkLimit(user, inputTokens, effectiveMaxTokens); - log.debug("User {} has {} remaining tokens left", user, remaining); - if (remaining <= 0) { - log.info("Token limit exceeded. Remaining: {} input: {} maxOutput: {}", remaining, inputTokens, effectiveMaxTokens); - exc.setResponse(errorCreator.tokenLimitExceeded(inputTokens + effectiveMaxTokens, remaining, store.getRemainingResetTime())); - return RETURN; - } - } - // If APIKey is specified, use that for the LLM. Overwrites keys from the client if (apiKey != null) { - aiReq.setApiKey(apiKey); + llmReq.setApiKey(apiKey); + } + + if (!exc.getRequest().isPOSTRequest()) { + return CONTINUE; } - log.debug("Requested model: {}", aiReq.getModel()); + if (!(llmReq instanceof ModelInputRequest mir)) { + return CONTINUE; + } - var outcome = policies.handleRequest(aiReq,exc); + var outcome = policies.handleRequest(mir, exc); if (outcome != CONTINUE) { return outcome; } if (systemPrompt != null) { - outcome = systemPrompt.handleRequest(aiReq,exc); + outcome = systemPrompt.handleRequest(mir, exc); if (outcome != CONTINUE) { return outcome; } } - setJsonBody(exc.getRequest(), aiReq.getJson()); + // Check store limits + if (checkStoreLimits(exc, mir, user) != CONTINUE) { + return RETURN; + } + + exc.getRequest().setBodyContent(mir.getBody().getContent()); + return CONTINUE; + } + + private Outcome checkStoreLimits(Exchange exc, ModelInputRequest mir, AiApiUser user) { + long inputTokens = mir.estimateInputTokens(); + log.debug("Estimated input tokens: {}", inputTokens); + if (store != null) { + var effectiveMaxTokens = computeEffectiveMaxOutputTokens(mir.getRequestedMaxOutputTokens(), policies.getMaxOutputTokens()); + var remaining = store.checkLimit(user, inputTokens, effectiveMaxTokens); + log.debug("User {} has {} remaining tokens left", user, remaining); + if (remaining <= 0) { + log.info("Token limit exceeded. Remaining: {} input: {} maxOutput: {}", remaining, inputTokens, effectiveMaxTokens); + exc.setResponse(errorCreator.tokenLimitExceeded(inputTokens + effectiveMaxTokens, remaining, store.getRemainingResetTime())); + return RETURN; + } + } return CONTINUE; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java index 8acd6df555..a1ba392b3b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/NullPolicies.java @@ -3,14 +3,14 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.Outcome; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.ModelInputRequest; import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; public class NullPolicies implements Policies { @Override - public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { + public Outcome handleRequest(ModelInputRequest mir, Exchange exc) { return CONTINUE; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java index 62419f0ed2..fc742e30ce 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/Policies.java @@ -3,11 +3,11 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.Outcome; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.ModelInputRequest; public interface Policies { - Outcome handleRequest(LLMRequest aiReq, Exchange exc); + Outcome handleRequest(ModelInputRequest mir, Exchange exc); void init(LLMErrorCreator errorCreator); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java index e969cec75f..e2382e0135 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/SystemPrompt.java @@ -4,7 +4,7 @@ import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.Outcome; -import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.ModelInputRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,24 +28,24 @@ public enum Action { private Action action = Action.OVERWRITE; private String content = ""; - public Outcome handleRequest(LLMRequest aiReq, Exchange exc) { - var instructions = aiReq.getSystemPrompt() == null ? "" : aiReq.getSystemPrompt(); + public Outcome handleRequest(ModelInputRequest mir, Exchange exc) { + var instructions = mir.getSystemPrompt() == null ? "" : mir.getSystemPrompt(); switch (action) { case OVERWRITE -> { log.debug("Overwriting instructions: {}", content); - aiReq.setSystemPrompts(List.of(content)); + mir.setSystemPrompts(List.of(content)); } case PREPEND -> { log.debug("Prepending instructions: {}", content); - aiReq.setSystemPrompts(List.of(content, instructions)); + mir.setSystemPrompts(List.of(content, instructions)); } case APPEND -> { log.debug("Appending instructions: {}", content); - aiReq.setSystemPrompts(List.of(instructions, content)); + mir.setSystemPrompts(List.of(instructions, content)); } case REMOVE -> { log.info("Removing instructions: {}", instructions); - aiReq.removeSystemPrompt(); + mir.removeSystemPrompt(); } } return CONTINUE; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java new file mode 100644 index 0000000000..6273c2c5db --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java @@ -0,0 +1,35 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.http.ReadingBodyException; +import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.AudioRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.FilesRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.ImagesRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.OrganizationRequest; + +import java.io.IOException; + +public abstract class AbstractLLMProvider implements LLMProvider { + + @Override + public LLMRequest getLLMRequest(Exchange exchange) throws IOException { + var uri = exchange.getRequest().getUri(); + if (uri.startsWith("/v1/chat/completions")) { + return new ChatCompletionsRequest(exchange); + } + if (uri.startsWith("/v1/files")) { + return new FilesRequest(exchange); + } + if (uri.contains("/v1/images")) { + return new ImagesRequest(exchange); + } + if (uri.contains("/v1/audio")) { + return new AudioRequest(exchange); + } + if (uri.contains("/v1/organization")) { + return new OrganizationRequest(exchange); + } + throw new ReadingBodyException("Unknown request: " + uri); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java index 54df557594..4c7c8fded3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java @@ -1,60 +1,15 @@ -/* Copyright 2026 predic8 GmbH, www.predic8.com - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - package com.predic8.membrane.core.interceptor.llmgateway.provider; -import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.util.json.JsonUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Collections; -import java.util.List; import static com.predic8.membrane.core.http.Header.AUTHORIZATION; -public abstract class AbstractLLMRequest extends AbstractLLMMessage implements LLMRequest { - - private static final Logger log = LoggerFactory.getLogger(AbstractLLMRequest.class); +public class AbstractLLMRequest extends AbstractLLMMessage implements LLMRequest { public static final String BEARER_PREFIX = "Bearer"; - protected ObjectNode json; - - public AbstractLLMRequest(Exchange exchange) { + protected AbstractLLMRequest(Exchange exchange) { super(exchange); - - if (exchange.getRequest().isJSON()) { - json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("Cannot parse input as JSON message.")); - } else { - log.info("Request is not JSON:"); - throw new RuntimeException("Request is not JSON."); - } - } - - public List getTools() { - return Collections.emptyList(); - } - - protected ArrayNode getToolsNode() { - if (json == null) - return null; - if (json.path("tools").isArray()) - return (ArrayNode) json.path("tools"); - return null; } @Override @@ -79,13 +34,4 @@ public String getApiKey() { return token.isEmpty() ? null : token; } - @Override - public ObjectNode getJson() { - return json; - } - - @Override - public String getModel() { - return json.path("model").asText(); - } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/JSONRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/JSONRequest.java new file mode 100644 index 0000000000..2859b6fc38 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/JSONRequest.java @@ -0,0 +1,8 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.fasterxml.jackson.databind.node.ObjectNode; + +interface JSONMessage { + + ObjectNode getJson(); +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java index 1fb2fc4eae..457597d70e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMProvider.java @@ -16,11 +16,12 @@ import com.predic8.membrane.core.exchange.Exchange; +import java.io.IOException; import java.util.function.Consumer; public interface LLMProvider { - LLMRequest getLLMRequest(Exchange request); + LLMRequest getLLMRequest(Exchange request) throws IOException; LLMResponse getLLMResponse(Exchange request, Consumer postProcessor); LLMErrorCreator getErrorCreator(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java index 31a65919c9..f80230a755 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/LLMRequest.java @@ -14,38 +14,11 @@ package com.predic8.membrane.core.interceptor.llmgateway.provider; -import com.fasterxml.jackson.databind.node.ObjectNode; - -import java.util.List; - public interface LLMRequest { - String getModel(); - String getApiKey(); void setApiKey(String apiKey); - /** - * The max number of tokens that the model is allowed to generate as specified by the client. - * @return The max number of tokens that the model is allowed to generate. -1 if no limit is set. - */ - long getRequestedMaxOutputTokens(); - - void setMaxOutputTokens(int maxOutputTokens); - - long estimateInputTokens(); - - ObjectNode getJson(); - - List getTools(); - - String getSystemPrompt(); - - boolean isChatCompletion(); - - void setSystemPrompts(List prompts); - - void removeSystemPrompt(); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/ModelInputRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/ModelInputRequest.java new file mode 100644 index 0000000000..4a779a140c --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/ModelInputRequest.java @@ -0,0 +1,30 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.predic8.membrane.core.http.AbstractBody; + +import java.util.List; + +public interface ModelInputRequest extends JSONMessage { + + String getModel(); + + /** + * The max number of tokens that the model is allowed to generate as specified by the client. + * @return The max number of tokens that the model is allowed to generate. -1 if no limit is set. + */ + long getRequestedMaxOutputTokens(); + + void setMaxOutputTokens(int maxOutputTokens); + + long estimateInputTokens(); + + List getTools(); + + String getSystemPrompt(); + + void setSystemPrompts(List prompts); + + void removeSystemPrompt(); + + AbstractBody getBody(); +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java index e2089bbcc0..8f1a7a491a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsProvider.java @@ -21,6 +21,7 @@ import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import java.io.IOException; import java.util.function.Consumer; /** @@ -46,7 +47,7 @@ @MCElement(name = "chatCompletions") public class ChatCompletionsProvider implements LLMProvider { @Override - public LLMRequest getLLMRequest(Exchange request) { + public LLMRequest getLLMRequest(Exchange request) throws IOException { return new ChatCompletionsRequest(request); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java index e9f11b2db9..b50a536131 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/chatcompletions/ChatCompletionsRequest.java @@ -17,13 +17,14 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.AbstractOpenAiLLMRequest; +import java.io.IOException; import java.util.List; import static java.util.Collections.emptyList; public class ChatCompletionsRequest extends AbstractOpenAiLLMRequest { - public ChatCompletionsRequest(Exchange exchange) { + public ChatCompletionsRequest(Exchange exchange) throws IOException { super(exchange); if (json == null) { @@ -107,11 +108,6 @@ public void removeSystemPrompt() { } } - @Override - public boolean isChatCompletion() { - return true; - } - @Override public long getRequestedMaxOutputTokens() { // Prefer max_completion_tokens (modern OpenAI/o1+), fall back to max_tokens (legacy / all other providers) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java index 40eb7261ec..1a0e66c3c3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeLLMRequest.java @@ -16,22 +16,24 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractModelInputRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.ModelInputRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.util.List; /** * system field for system prompt */ -public class ClaudeLLMRequest extends AbstractLLMRequest { +public class ClaudeLLMRequest extends AbstractModelInputRequest implements ModelInputRequest { private static final Logger log = LoggerFactory.getLogger(ClaudeLLMRequest.class); public static final String X_API_KEY = "x-api-key"; - public ClaudeLLMRequest(Exchange exchange) { + public ClaudeLLMRequest(Exchange exchange) throws IOException { super(exchange); exchange.getRequest().getHeader().setValue( "Accept-Encoding","identity"); @@ -94,11 +96,6 @@ public String getSystemPrompt() { return json.path("system").asText(""); } - @Override - public boolean isChatCompletion() { - return false; - } - private boolean isThinking() { var thinking = json.path("thinking"); return thinking.isObject() && "enabled".equals(thinking.path("type").asText()); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java index a296575058..decc7048b2 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java @@ -21,6 +21,7 @@ import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import java.io.IOException; import java.util.function.Consumer; /** @@ -31,7 +32,7 @@ public class ClaudeProvider implements LLMProvider { @Override - public LLMRequest getLLMRequest(Exchange exchange) { + public LLMRequest getLLMRequest(Exchange exchange) throws IOException { return new ClaudeLLMRequest(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java index adb62a63f4..90f1b1ab36 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleLLMRequest.java @@ -17,18 +17,20 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractModelInputRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.ModelInputRequest; +import java.io.IOException; import java.util.List; -public class GoogleLLMRequest extends AbstractLLMRequest { +public class GoogleLLMRequest extends AbstractModelInputRequest implements ModelInputRequest { /** * x-goog-api-key is correct it is not google */ public static final String X_GOOG_API_KEY = "x-goog-api-key"; - public GoogleLLMRequest(Exchange exchange) { + public GoogleLLMRequest(Exchange exchange) throws IOException { super(exchange); } @@ -142,12 +144,6 @@ public void removeSystemPrompt() { json.remove("systemInstruction"); } - @Override - public boolean isChatCompletion() { - // Gemini uses its own generateContent API, not Chat Completions - return false; - } - private long countText(JsonNode node) { if (node == null || node.isMissingNode() || node.isNull()) { return 0; diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java index b1b36ea1df..0654b9b52f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/google/GoogleProvider.java @@ -21,6 +21,7 @@ import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; +import java.io.IOException; import java.util.function.Consumer; /** @@ -31,7 +32,7 @@ public class GoogleProvider implements LLMProvider { @Override - public LLMRequest getLLMRequest(Exchange exchange) { + public LLMRequest getLLMRequest(Exchange exchange) throws IOException { return new GoogleLLMRequest(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java index 0686831bba..9e75ef5ec5 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AbstractOpenAiLLMRequest.java @@ -16,11 +16,13 @@ import com.fasterxml.jackson.databind.JsonNode; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractModelInputRequest; -public abstract class AbstractOpenAiLLMRequest extends AbstractLLMRequest { +import java.io.IOException; - public AbstractOpenAiLLMRequest(Exchange exchange) { +public abstract class AbstractOpenAiLLMRequest extends AbstractModelInputRequest { + + public AbstractOpenAiLLMRequest(Exchange exchange) throws IOException { super(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java new file mode 100644 index 0000000000..df028457a3 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java @@ -0,0 +1,12 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractModelInputRequest; + +import java.io.IOException; + +public class AudioRequest extends AbstractModelInputRequest { + public AudioRequest(Exchange exchange) throws IOException { + super(exchange); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java new file mode 100644 index 0000000000..cc56d7b492 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java @@ -0,0 +1,12 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; + +public class FilesRequest extends AbstractLLMRequest { + + public FilesRequest(Exchange exchange) { + super(exchange); + } + +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java new file mode 100644 index 0000000000..74ac706d04 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java @@ -0,0 +1,11 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; + +public class ImagesRequest extends AbstractLLMRequest { + + public ImagesRequest(Exchange exchange) { + super(exchange); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java index 8c6e474398..b26e2794e2 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIChatCompletionsRequest.java @@ -17,8 +17,10 @@ import com.predic8.membrane.core.exchange.Exchange; import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsRequest; +import java.io.IOException; + public class OpenAIChatCompletionsRequest extends ChatCompletionsRequest { - public OpenAIChatCompletionsRequest(Exchange exchange) { + public OpenAIChatCompletionsRequest(Exchange exchange) throws IOException { super(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java index e55d40bd47..9798483ed5 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java @@ -16,13 +16,14 @@ import com.predic8.membrane.annot.MCElement; import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMProvider; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMErrorCreator; -import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMProvider; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMRequest; import com.predic8.membrane.core.interceptor.llmgateway.provider.LLMResponse; import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsErrorCreator; import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsResponse; +import java.io.IOException; import java.util.function.Consumer; /** @@ -30,19 +31,21 @@ * Use to configure a LLM gateway to use the OpenAI API */ @MCElement( name="openai") -public class OpenAIProvider implements LLMProvider { +public class OpenAIProvider extends AbstractLLMProvider { @Override - public LLMRequest getLLMRequest(Exchange exchange) { - if (isResponsesApi(exchange)) { + public LLMRequest getLLMRequest(Exchange exchange) throws IOException { + var uri = exchange.getRequest().getUri(); + if (uri.startsWith("/v1/responses")) { return new OpenAiLLMResponsesRequest(exchange); } - return new OpenAIChatCompletionsRequest(exchange); + return super.getLLMRequest(exchange); } @Override public LLMResponse getLLMResponse(Exchange exchange, Consumer postProcessor) { - if (isResponsesApi(exchange)) { + var uri = exchange.getRequest().getUri(); + if (uri.startsWith("/v1/responses")) { return new OpenAiLLMResponsesResponse(exchange,postProcessor); } return new ChatCompletionsResponse(exchange, postProcessor); @@ -53,7 +56,4 @@ public LLMErrorCreator getErrorCreator() { return new ChatCompletionsErrorCreator(); } - static boolean isResponsesApi(Exchange exchange) { - return exchange.getRequest().getUri().startsWith("/v1/responses"); - } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java index 7825772d92..945ede46e4 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAiLLMResponsesRequest.java @@ -16,13 +16,14 @@ import com.predic8.membrane.core.exchange.Exchange; +import java.io.IOException; import java.util.List; import static java.util.Collections.emptyList; public class OpenAiLLMResponsesRequest extends AbstractOpenAiLLMRequest { - public OpenAiLLMResponsesRequest(Exchange exchange) { + public OpenAiLLMResponsesRequest(Exchange exchange) throws IOException { super(exchange); } @@ -42,11 +43,6 @@ public String getSystemPrompt() { return json.path("instructions").asText(""); } - @Override - public boolean isChatCompletion() { - return false; - } - /** * Concatenates all prompts (newline-separated) into the {@code "instructions"} field. * diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OrganizationRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OrganizationRequest.java new file mode 100644 index 0000000000..8d4b5bbb25 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OrganizationRequest.java @@ -0,0 +1,11 @@ +package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; + +import com.predic8.membrane.core.exchange.Exchange; + +import java.io.IOException; + +public class OrganizationRequest extends AbstractOpenAiLLMRequest { + public OrganizationRequest(Exchange exchange) throws IOException { + super(exchange); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/SimpleAiApiStore.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/SimpleAiApiStore.java index 106892c39f..9f7e91b210 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/SimpleAiApiStore.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/store/SimpleAiApiStore.java @@ -52,7 +52,7 @@ public class SimpleAiApiStore implements AiApiStore { @Override public void store(AiApiUser user, Usage usage) { if (logUsage) - log.info("user: {} {}",user.getName(),usage.toString()); + log.info("user: {} {}", user.getName(), usage.toString()); user.addTokensUsedInPeriod(usage); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/log/LogInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/log/LogInterceptor.java index 34fe64dbad..b2f2f10a5c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/log/LogInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/log/LogInterceptor.java @@ -146,6 +146,10 @@ private String dumpHeaderFields(Message msg) { } private static String dumpBody(Message msg) { + if (msg.isBinary()) { + return "[Binary]"; + } + try { return "Body:\n%s\n".formatted(msg.getBodyAsStringDecoded()); } catch (Exception e) { diff --git a/core/src/main/java/com/predic8/membrane/core/multipart/MultipartUtil.java b/core/src/main/java/com/predic8/membrane/core/multipart/MultipartUtil.java new file mode 100644 index 0000000000..d19a808e3a --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/multipart/MultipartUtil.java @@ -0,0 +1,98 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.multipart; + +import com.predic8.membrane.core.http.Header; +import com.predic8.membrane.core.http.Message; +import com.predic8.membrane.core.util.MessageUtil; +import jakarta.mail.internet.ParseException; +import org.apache.commons.fileupload.MultipartStream; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static java.nio.charset.StandardCharsets.UTF_8; + +/** + * Utility for splitting multipart HTTP messages into their individual {@link Part}s. + * + *

Example: + *

{@code
+ * List parts = MultipartUtil.split(exchange.getRequest());
+ * for (Part part : parts) {
+ *     String name = part.getName();          // form field name
+ *     String type = part.getContentType();   // e.g. "image/png"
+ *     byte[] body = part.getBody();
+ * }
+ * }
+ */ +public class MultipartUtil { + + /** + * Splits a multipart message into its individual parts. + * The MIME boundary is read from the message's {@code Content-Type} header. + * + * @param message a request or response whose Content-Type is multipart/* + * @return parts in wire order; never null, may be empty + * @throws IOException on I/O or parse errors + * @throws ParseException if the Content-Type header cannot be parsed + */ + public static List split(Message message) throws IOException, ParseException { + var contentType = message.getHeader().getContentTypeObject(); + if (contentType == null) { + throw new IOException("No Content-Type header"); + } + String boundary = contentType.getParameter("boundary"); + if (boundary == null) { + throw new IOException("No boundary parameter in Content-Type: " + contentType); + } + return split(message, boundary); + } + + /** + * Splits a multipart message into its individual parts using an explicit boundary. + * + * @param message a request or response with a multipart body + * @param boundary the MIME boundary string (without leading {@code --}) + * @return parts in wire order; never null, may be empty + * @throws IOException on I/O or unsupported Content-Transfer-Encoding + */ + @SuppressWarnings("deprecation") + public static List split(Message message, String boundary) throws IOException { + List result = new ArrayList<>(); + + MultipartStream ms = new MultipartStream(MessageUtil.getContentAsStream(message), boundary.getBytes(UTF_8)); + boolean hasNext = ms.skipPreamble(); + while (hasNext) { + Header partHeader = new Header(ms.readHeaders()); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ms.readBodyData(baos); + + // Only binary-safe encodings are supported; base64/QP would corrupt binary parts + String cte = partHeader.getFirstValue("Content-Transfer-Encoding"); + if (cte != null && !cte.equalsIgnoreCase("binary") + && !cte.equalsIgnoreCase("8bit") + && !cte.equalsIgnoreCase("7bit")) { + throw new IOException("Content-Transfer-Encoding '" + cte + "' is not supported."); + } + + result.add(new Part(partHeader, baos.toByteArray())); + hasNext = ms.readBoundary(); + } + return result; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/multipart/Part.java b/core/src/main/java/com/predic8/membrane/core/multipart/Part.java index 6d3e58c027..5415954eec 100644 --- a/core/src/main/java/com/predic8/membrane/core/multipart/Part.java +++ b/core/src/main/java/com/predic8/membrane/core/multipart/Part.java @@ -14,141 +14,120 @@ package com.predic8.membrane.core.multipart; -import com.predic8.membrane.core.http.*; +import com.predic8.membrane.core.http.Header; -import javax.xml.namespace.*; -import javax.xml.stream.*; -import javax.xml.stream.events.*; -import java.io.*; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.regex.Matcher; +import java.util.regex.Pattern; -import static java.nio.charset.StandardCharsets.*; -import static org.apache.commons.codec.binary.Base64.*; +import static java.nio.charset.StandardCharsets.UTF_8; +/** + * A single part of a multipart HTTP message, consisting of a header block and a body. + * + * @see MultipartUtil#split(com.predic8.membrane.core.http.Message) + */ public class Part { - private final Header header; - private final byte[] data; - - public Part(Header header, byte[] data) { - this.header = header; - this.data = data; - } - - public String getContentID() { - return header.getFirstValue("Content-ID"); - } - - public Header getHeader() { - return header; - } - - public InputStream getInputStream() { - return new ByteArrayInputStream(data); - } - - public XMLEvent asXMLEvent() { - return new Characters() { - - @Override - public void writeAsEncodedUnicode(Writer writer) { - throw new RuntimeException("not implemented"); - } - - @Override - public boolean isStartElement() { - return false; - } - - @Override - public boolean isStartDocument() { - return false; - } - - @Override - public boolean isProcessingInstruction() { - return false; - } - - @Override - public boolean isNamespace() { - return false; - } - - @Override - public boolean isEntityReference() { - return false; - } - - @Override - public boolean isEndElement() { - return false; - } - - @Override - public boolean isEndDocument() { - return false; - } - - @Override - public boolean isCharacters() { - return true; - } - - @Override - public boolean isAttribute() { - return false; - } - - @Override - public QName getSchemaType() { - return null; - } - - @Override - public Location getLocation() { - return null; - } - - @Override - public int getEventType() { - return CHARACTERS; - } - - @Override - public StartElement asStartElement() { - return null; - } - - @Override - public EndElement asEndElement() { - return null; - } - - @Override - public Characters asCharacters() { - return this; - } - - @Override - public String getData() { - return new String(encodeBase64(data), UTF_8); - } - - @Override - public boolean isWhiteSpace() { - return false; - } - - @Override - public boolean isCData() { - return false; - } - - @Override - public boolean isIgnorableWhiteSpace() { - return false; - } - }; - } - + private static final Pattern NAME_PATTERN = + Pattern.compile("(?i)\\bname=\"([^\"]+)\""); + private static final Pattern FILENAME_PATTERN = + Pattern.compile("(?i)\\bfilename=\"([^\"]+)\""); + + private final Header header; + private final byte[] body; + + public Part(Header header, byte[] body) { + this.header = header; + this.body = body; + } + + // ------------------------------------------------------------------------- + // Header accessors + // ------------------------------------------------------------------------- + + /** + * Returns the part's own header block (may contain Content-Type, Content-ID, etc.). + */ + public Header getHeader() { + return header; + } + + /** + * Returns the {@code Content-ID} header value, or {@code null} if absent. + * Used in MIME multipart/related messages (e.g. SOAP XOP). + */ + public String getContentID() { + return header.getFirstValue("Content-ID"); + } + + /** + * Returns the {@code Content-Type} of this part (e.g. {@code "image/png"}), + * or {@code null} if no Content-Type header is present. + */ + public String getContentType() { + return header.getContentType(); + } + + /** + * Returns the {@code name} parameter from the {@code Content-Disposition} header. + * This is the form field name in {@code multipart/form-data} submissions. + * Returns {@code null} if not present. + */ + public String getName() { + return extractDispositionParam(NAME_PATTERN); + } + + /** + * Returns the {@code filename} parameter from the {@code Content-Disposition} header, + * or {@code null} if not present. + */ + public String getFilename() { + return extractDispositionParam(FILENAME_PATTERN); + } + + // ------------------------------------------------------------------------- + // Body accessors + // ------------------------------------------------------------------------- + + /** + * Returns the raw body bytes of this part. + */ + public byte[] getBody() { + return body; + } + + /** + * Returns the body decoded as a UTF-8 string. + */ + public String getBodyAsString() { + return getBodyAsString(UTF_8); + } + + /** + * Returns the body decoded using the given charset. + */ + public String getBodyAsString(Charset charset) { + return new String(body, charset); + } + + /** + * Returns a fresh {@link InputStream} over the body bytes. + */ + public InputStream getInputStream() { + return new ByteArrayInputStream(body); + } + + // ------------------------------------------------------------------------- + // Internal helpers + // ------------------------------------------------------------------------- + + private String extractDispositionParam(Pattern pattern) { + String disposition = header.getFirstValue("Content-Disposition"); + if (disposition == null) return null; + Matcher m = pattern.matcher(disposition); + return m.find() ? m.group(1) : null; + } } diff --git a/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java b/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java index e9aa70e6e2..1a5e14ee8e 100644 --- a/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java +++ b/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java @@ -14,20 +14,31 @@ package com.predic8.membrane.core.multipart; -import com.predic8.membrane.core.http.*; -import com.predic8.membrane.core.util.*; -import jakarta.mail.internet.*; -import org.apache.commons.fileupload.*; -import org.slf4j.*; - -import javax.annotation.concurrent.*; -import javax.xml.namespace.*; +import com.predic8.membrane.core.http.BodyCollectingMessageObserver; +import com.predic8.membrane.core.http.Header; +import com.predic8.membrane.core.http.Message; +import com.predic8.membrane.core.util.EndOfStreamException; +import com.predic8.membrane.core.util.MessageUtil; +import jakarta.mail.internet.ContentType; +import jakarta.mail.internet.ParseException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.ThreadSafe; +import javax.xml.namespace.QName; import javax.xml.stream.*; -import javax.xml.stream.events.*; -import java.io.*; -import java.util.*; - -import static java.nio.charset.StandardCharsets.*; +import javax.xml.stream.events.Characters; +import javax.xml.stream.events.EndElement; +import javax.xml.stream.events.StartElement; +import javax.xml.stream.events.XMLEvent; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.Writer; +import java.util.HashMap; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.apache.commons.codec.binary.Base64.encodeBase64; /** * Reassemble a multipart XOP message (see @@ -90,7 +101,7 @@ public Message getReconstitutedMessage(Message message) throws ParseException, I if (boundary == null) return null; - HashMap parts = split(message, boundary); + HashMap parts = splitById(message, boundary); Part startPart = parts.get(start); if (startPart == null) return null; @@ -132,36 +143,16 @@ public boolean shouldNotContainBody() { return m; } - @SuppressWarnings("deprecation") - private HashMap split(Message message, String boundary) - throws IOException, EndOfStreamException { - HashMap parts = new HashMap<>(); - - MultipartStream multipartStream = new MultipartStream(MessageUtil.getContentAsStream(message), boundary.getBytes(UTF_8)); - boolean nextPart = multipartStream.skipPreamble(); - while(nextPart) { - Header header = new Header(multipartStream.readHeaders()); - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - multipartStream.readBodyData(baos); - - // see http://www.iana.org/assignments/transfer-encodings/transfer-encodings.xml - String cte = header.getFirstValue("Content-Transfer-Encoding"); - if (cte != null && - !cte.equals("binary") && - !cte.equals("8bit") && - !cte.equals("7bit")) - throw new RuntimeException("Content-Transfer-Encoding '" + cte + "' not implemented."); - - - Part part = new Part(header, baos.toByteArray()); + /** Splits the multipart message and indexes parts by Content-ID for XOP lookup. */ + private HashMap splitById(Message message, String boundary) throws IOException { + HashMap byId = new HashMap<>(); + for (Part part : MultipartUtil.split(message, boundary)) { String id = part.getContentID(); if (id != null) { - parts.put(id, part); + byId.put(id, part); } - - nextPart = multipartStream.readBoundary(); } - return parts; + return byId; } private byte[] fillInXOPParts(InputStream inputStream, @@ -189,7 +180,7 @@ private byte[] fillInXOPParts(InputStream inputStream, if (p == null) throw new RuntimeException("Did not find multipart with id " + href); - writer.add(p.asXMLEvent()); + writer.add(base64CharactersEvent(p.getBody())); xopIncludeOpen = true; continue; } @@ -212,4 +203,33 @@ private byte[] fillInXOPParts(InputStream inputStream, return baos.toByteArray(); } + /** Wraps raw bytes as a base64-encoded XML Characters event for XOP inlining. */ + private static Characters base64CharactersEvent(byte[] data) { + String encoded = new String(encodeBase64(data), UTF_8); + return new Characters() { + @Override public String getData() { return encoded; } + @Override public boolean isCharacters() { return true; } + @Override public boolean isWhiteSpace() { return false; } + @Override public boolean isCData() { return false; } + @Override public boolean isIgnorableWhiteSpace() { return false; } + @Override public int getEventType() { return CHARACTERS; } + @Override public Characters asCharacters() { return this; } + @Override public boolean isStartElement() { return false; } + @Override public boolean isEndElement() { return false; } + @Override public boolean isStartDocument() { return false; } + @Override public boolean isEndDocument() { return false; } + @Override public boolean isAttribute() { return false; } + @Override public boolean isNamespace() { return false; } + @Override public boolean isEntityReference() { return false; } + @Override public boolean isProcessingInstruction() { return false; } + @Override public QName getSchemaType() { return null; } + @Override public Location getLocation() { return null; } + @Override public StartElement asStartElement() { return null; } + @Override public EndElement asEndElement() { return null; } + @Override public void writeAsEncodedUnicode(Writer writer) { + throw new UnsupportedOperationException(); + } + }; + } + } diff --git a/core/src/test/java/com/predic8/membrane/core/http/HeaderTest.java b/core/src/test/java/com/predic8/membrane/core/http/HeaderTest.java index 2597281778..add23dfc42 100644 --- a/core/src/test/java/com/predic8/membrane/core/http/HeaderTest.java +++ b/core/src/test/java/com/predic8/membrane/core/http/HeaderTest.java @@ -15,15 +15,21 @@ package com.predic8.membrane.core.http; import jakarta.activation.MimeType; -import org.junit.jupiter.api.*; -import org.junit.jupiter.params.*; -import org.junit.jupiter.params.provider.*; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; -import java.util.*; +import java.util.HashSet; +import java.util.List; import static com.predic8.membrane.core.http.Header.*; -import static com.predic8.membrane.core.http.MimeType.*; -import static java.nio.charset.StandardCharsets.*; +import static com.predic8.membrane.core.http.MimeType.TEXT_XML; +import static com.predic8.membrane.core.http.MimeType.isBinary; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.jupiter.api.Assertions.*; class HeaderTest { @@ -262,4 +268,47 @@ void unique() { assertEquals("1, 2", h.getValuesAsString("X-Foo")); assertEquals("3, 4", h.getValuesAsString("X-BAR")); } + + @Nested + class IsMultipart { + @Test + void formDataIsMultipart() { + var h = new Header(); + h.add("Content-Type", "multipart/form-data; boundary=abc"); + assertTrue(h.isMultipart()); + } + + @Test + void relatedIsMultipart() { + var h = new Header(); + h.add("Content-Type", "multipart/related; boundary=abc"); + assertTrue(h.isMultipart()); + } + + @Test + void mixedIsMultipart() { + var h = new Header(); + h.add("Content-Type", "multipart/mixed; boundary=abc"); + assertTrue(h.isMultipart()); + } + + @Test + void isCaseInsensitive() { + var h = new Header(); + h.add("Content-Type", "Multipart/Form-Data; boundary=abc"); + assertTrue(h.isMultipart()); + } + + @Test + void jsonIsNotMultipart() { + var h = new Header(); + h.add("Content-Type", "application/json"); + assertFalse(h.isMultipart()); + } + + @Test + void missingContentTypeIsNotMultipart() { + assertFalse(new Header().isMultipart()); + } + } } diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequestTest.java similarity index 79% rename from core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java rename to core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequestTest.java index af841ffed8..560d5edf58 100644 --- a/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequestTest.java +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequestTest.java @@ -4,12 +4,14 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import java.io.IOException; import java.net.URISyntaxException; +import java.util.List; import static com.predic8.membrane.core.http.Request.post; import static org.junit.jupiter.api.Assertions.assertEquals; -class AbstractLLMRequestTest { +class AbstractModelInputRequestTest { @ParameterizedTest @ValueSource(strings = { @@ -18,7 +20,7 @@ class AbstractLLMRequestTest { "BEARER test-api-key", "bEaReR test-api-key" }) - void getApiKeyAcceptsBearerCaseInsensitive(String authorization) throws URISyntaxException { + void getApiKeyAcceptsBearerCaseInsensitive(String authorization) throws URISyntaxException, IOException { var request = new TestLLMRequest(post("http://localhost/chat/completions") .header("Authorization", authorization) .json("{}") @@ -27,9 +29,9 @@ void getApiKeyAcceptsBearerCaseInsensitive(String authorization) throws URISynta assertEquals("test-api-key", request.getApiKey()); } - private static class TestLLMRequest extends AbstractLLMRequest { + private static class TestLLMRequest extends AbstractModelInputRequest implements ModelInputRequest { - TestLLMRequest(Exchange exchange) { + TestLLMRequest(Exchange exchange) throws IOException { super(exchange); } @@ -53,12 +55,7 @@ public String getSystemPrompt() { } @Override - public boolean isChatCompletion() { - return false; - } - - @Override - public void setSystemPrompt(String systemPrompt) { + public void setSystemPrompts(List prompts) { } diff --git a/core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java b/core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java new file mode 100644 index 0000000000..950896246b --- /dev/null +++ b/core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java @@ -0,0 +1,229 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.multipart; + +import com.predic8.membrane.core.http.Response; +import jakarta.mail.internet.ParseException; +import org.apache.commons.io.IOUtils; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.*; + +class MultipartUtilTest { + + private static final String BOUNDARY = "test-boundary-123"; + private static final String CRLF = "\r\n"; + + // ------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------- + + /** Builds a Response with the given multipart body and boundary. */ + private static Response response(String body) { + return response(body, BOUNDARY); + } + + private static Response response(String body, String boundary) { + byte[] bytes = body.getBytes(UTF_8); + return Response.ok() + .header("Content-Type", "multipart/form-data; boundary=\"" + boundary + "\"") + .header("Content-Length", String.valueOf(bytes.length)) + .body(bytes) + .build(); + } + + /** + * Builds a minimal multipart body. + * Each {@code part} string should contain headers + blank line + body, + * e.g. {@code "Content-Disposition: form-data; name=\"x\"\r\n\r\nvalue"}. + */ + private static String multipartBody(String... parts) { + var sb = new StringBuilder(); + for (String part : parts) { + sb.append("--").append(BOUNDARY).append(CRLF); + sb.append(part).append(CRLF); + } + sb.append("--").append(BOUNDARY).append("--").append(CRLF); + return sb.toString(); + } + + private static String formField(String name, String value) { + return "Content-Disposition: form-data; name=\"" + name + "\"" + CRLF + CRLF + value; + } + + // ------------------------------------------------------------------------- + // split(Message) — auto-reads boundary from Content-Type + // ------------------------------------------------------------------------- + + @Test + void twoFormFieldsAreReturnedInOrder() throws IOException, ParseException { + var parts = MultipartUtil.split(response(multipartBody( + formField("username", "alice"), + formField("message", "Hello World") + ))); + + assertEquals(2, parts.size()); + assertEquals("username", parts.get(0).getName()); + assertEquals("alice", parts.get(0).getBodyAsString()); + assertEquals("message", parts.get(1).getName()); + assertEquals("Hello World", parts.get(1).getBodyAsString()); + } + + @Test + void fileUploadPartExposesFilenameAndContentType() throws IOException, ParseException { + String part = "Content-Disposition: form-data; name=\"upload\"; filename=\"photo.jpg\"" + CRLF + + "Content-Type: image/jpeg" + CRLF + + CRLF + + "JFIF"; + + var parts = MultipartUtil.split(response(multipartBody(part))); + + assertEquals(1, parts.size()); + assertEquals("upload", parts.get(0).getName()); + assertEquals("photo.jpg", parts.get(0).getFilename()); + assertEquals("image/jpeg", parts.get(0).getContentType()); + assertArrayEquals("JFIF".getBytes(UTF_8), parts.get(0).getBody()); + } + + @Test + void partWithContentIdIsAccessible() throws IOException, ParseException { + String part = "Content-Type: application/octet-stream" + CRLF + + "Content-ID: " + CRLF + + CRLF + + "binary"; + + var parts = MultipartUtil.split(response(multipartBody(part))); + + assertEquals("", parts.get(0).getContentID()); + } + + @Test + void binaryBodyIsPreservedExactly() throws IOException, ParseException { + byte[] payload = {0, 1, 2, (byte) 0xFF, (byte) 0xFE}; + String header = "Content-Type: application/octet-stream" + CRLF + CRLF; + byte[] partBytes = (header).getBytes(UTF_8); + byte[] fullPart = new byte[partBytes.length + payload.length]; + System.arraycopy(partBytes, 0, fullPart, 0, partBytes.length); + System.arraycopy(payload, 0, fullPart, partBytes.length, payload.length); + + // Build body manually to embed raw bytes + byte[] prefix = ("--" + BOUNDARY + CRLF).getBytes(UTF_8); + byte[] suffix = (CRLF + "--" + BOUNDARY + "--" + CRLF).getBytes(UTF_8); + byte[] body = new byte[prefix.length + fullPart.length + suffix.length]; + System.arraycopy(prefix, 0, body, 0, prefix.length); + System.arraycopy(fullPart, 0, body, prefix.length, fullPart.length); + System.arraycopy(suffix, 0, body, prefix.length + fullPart.length, suffix.length); + + byte[] msgBytes = body; + var msg = Response.ok() + .header("Content-Type", "multipart/form-data; boundary=\"" + BOUNDARY + "\"") + .header("Content-Length", String.valueOf(msgBytes.length)) + .body(msgBytes) + .build(); + + var parts = MultipartUtil.split(msg); + assertArrayEquals(payload, parts.get(0).getBody()); + } + + // ------------------------------------------------------------------------- + // split(Message, boundary) — explicit boundary overload + // ------------------------------------------------------------------------- + + @Test + void explicitBoundaryOverloadProducesSameResult() throws IOException { + var body = multipartBody(formField("x", "42")); + byte[] bytes = body.getBytes(UTF_8); + var msg = Response.ok() + .header("Content-Type", "multipart/form-data; boundary=\"other\"") // intentionally wrong + .header("Content-Length", String.valueOf(bytes.length)) + .body(bytes) + .build(); + + // Pass the correct boundary explicitly — Content-Type boundary is ignored + var parts = MultipartUtil.split(msg, BOUNDARY); + + assertEquals(1, parts.size()); + assertEquals("x", parts.get(0).getName()); + assertEquals("42", parts.get(0).getBodyAsString()); + } + + // ------------------------------------------------------------------------- + // Real-world resource: XOP multipart from ReassembleTest + // ------------------------------------------------------------------------- + + @SuppressWarnings("DataFlowIssue") + @Test + void xopResourceSplitsIntoTwoParts() throws IOException { + byte[] body = IOUtils.toByteArray(getClass().getResourceAsStream("/multipart/embedded-byte-array.txt")); + var response = Response.ok() + .header("Content-Type", "multipart/related; " + + "type=\"application/xop+xml\"; " + + "boundary=\"uuid:168683dc-43b3-4e71-8e66-efb633ef406b\"; " + + "start=\"\"; " + + "start-info=\"text/xml\"") + .header("Content-Length", String.valueOf(body.length)) + .body(body) + .build(); + + var parts = MultipartUtil.split(response, "uuid:168683dc-43b3-4e71-8e66-efb633ef406b"); + + assertEquals(2, parts.size()); + assertEquals("", parts.get(0).getContentID()); + assertEquals("", parts.get(1).getContentID()); + assertEquals("application/xop+xml; charset=UTF-8; type=\"text/xml\";", parts.get(0).getContentType()); + assertEquals("application/octet-stream", parts.get(1).getContentType()); + } + + // ------------------------------------------------------------------------- + // Error cases + // ------------------------------------------------------------------------- + + @Test + void missingContentTypeThrows() { + byte[] bytes = "body".getBytes(UTF_8); + var msg = Response.ok() + .header("Content-Length", String.valueOf(bytes.length)) + .body(bytes) + .build(); + + assertThrows(IOException.class, () -> MultipartUtil.split(msg)); + } + + @Test + void missingBoundaryParameterThrows() { + byte[] bytes = "body".getBytes(UTF_8); + var msg = Response.ok() + .header("Content-Type", "multipart/form-data") // no boundary= + .header("Content-Length", String.valueOf(bytes.length)) + .body(bytes) + .build(); + + assertThrows(IOException.class, () -> MultipartUtil.split(msg)); + } + + @Test + void unsupportedContentTransferEncodingThrows() { + String part = "Content-Disposition: form-data; name=\"x\"" + CRLF + + "Content-Transfer-Encoding: quoted-printable" + CRLF + + CRLF + + "value"; + + assertThrows(IOException.class, + () -> MultipartUtil.split(response(multipartBody(part)))); + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java index 3cde3fa976..32614a7431 100644 --- a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/claude/BasicClaudeLLMGatewayTutorialTest.java @@ -62,6 +62,7 @@ void simpleRequestIsForwarded() throws IOException { .when() .post(LOCALHOST_2000 + "/v1/messages") .then() + .log().ifValidationFails() .statusCode(200) .body("type", equalTo("message")) .body("content[0].type", equalTo("text")); @@ -106,6 +107,7 @@ void outputTokensAreCappedBeforeForwarding() throws IOException { .when() .post(LOCALHOST_2000 + "/v1/messages") .then() + .log().everything() .statusCode(200); // @formatter:on diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java index 88a6d380ad..e1821bc28c 100644 --- a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/llmgateway/openai/SharingApiKeysOpenAiTutorialTest.java @@ -123,6 +123,7 @@ void wrongModelIsRejected() throws IOException { .when() .post(LOCALHOST_2000 + "/v1/responses") .then() + .log().ifValidationFails() .statusCode(400) .body("error.type", equalTo("invalid_request_error")) .body("error.code", equalTo("model_not_allowed")) From a7caaa4a1043eaa451f7009835da28c6d573be48 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 May 2026 15:09:49 +0200 Subject: [PATCH 41/43] feat: add `AbstractModelInputRequest` and enhance OpenAI provider request handling - Introduced `AbstractModelInputRequest` to support model input parsing for JSON and multipart requests. - Added `ChatCompletionsRequest` support in `OpenAIProvider` to handle `/v1/chat/completions` URI. - Simplified `DefaultPolicies` by consolidating `checkModel` call. - Minor cleanup in `MultipartUtilTest` and improved method signature in `XOPReconstitutor`. --- .../llmgateway/DefaultPolicies.java | 6 +- .../provider/AbstractModelInputRequest.java | 146 ++++++++++++++++++ .../provider/openai/OpenAIProvider.java | 3 + .../core/multipart/XOPReconstitutor.java | 2 +- .../core/multipart/MultipartUtilTest.java | 1 - 5 files changed, 151 insertions(+), 7 deletions(-) create mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java index 7c6f9612c4..1219fb66b3 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java @@ -58,11 +58,7 @@ public Outcome handleRequest(ModelInputRequest mir, Exchange exc) { if (outcome != CONTINUE) { return outcome; } - outcome = checkModel(mir, exc); - if (outcome != CONTINUE) { - return outcome; - } - return CONTINUE; + return checkModel(mir, exc); } public Outcome checkModel(ModelInputRequest mir, Exchange exc) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java new file mode 100644 index 0000000000..2fb780f214 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java @@ -0,0 +1,146 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.llmgateway.provider; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.http.AbstractBody; +import com.predic8.membrane.core.http.Body; +import com.predic8.membrane.core.multipart.MultipartUtil; +import com.predic8.membrane.core.util.json.JsonUtil; +import jakarta.mail.internet.ParseException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static java.nio.charset.StandardCharsets.UTF_8; + +public class AbstractModelInputRequest extends AbstractLLMRequest implements ModelInputRequest { + + private static final Logger log = LoggerFactory.getLogger(AbstractModelInputRequest.class); + + private static final ObjectMapper om = new ObjectMapper(); + + protected ObjectNode json; + + private String model; + + private AbstractBody body; + + public AbstractModelInputRequest(Exchange exchange) throws IOException { + super(exchange); + + if (exchange.getRequest().getHeader().isMultipart()) { + try { + for (var part : MultipartUtil.split(exchange.getRequest())) { + log.info("Part: name={} type={} size={}", part.getName(), part.getContentType(), part.getBody().length); + if ("model".equals(part.getName())) { + log.info("Model: {}", part.getBodyAsString()); + model = part.getBodyAsString(); + } + } + body = exchange.getRequest().getBody(); + } catch (IOException e) { + throw new RuntimeException(e); + } catch (ParseException e) { + throw new RuntimeException(e); + } + return; + } + + if (exchange.getRequest().isJSON()) { + json = JsonUtil.getJsonObject(exchange.getRequest()).orElseThrow(() -> new RuntimeException("Cannot parse input as JSON message.")); + } + + if (json != null) { + if (json.has("model")) { + model = json.path("model").asText(); + } + } + } + + public List getTools() { + return Collections.emptyList(); + } + + @Override + public String getSystemPrompt() { + return ""; + } + + @Override + public void setSystemPrompts(List prompts) { + log.warn("Not supported."); + } + + @Override + public void removeSystemPrompt() { + log.warn("Not supported."); + } + + protected ArrayNode getToolsNode() { + if (json == null) + return null; + if (json.path("tools").isArray()) + return (ArrayNode) json.path("tools"); + return null; + } + + + @Override + public ObjectNode getJson() { + return json; + } + + @Override + public String getModel() { + return model; + } + + @Override + public long getRequestedMaxOutputTokens() { + return 0; + } + + @Override + public void setMaxOutputTokens(int maxOutputTokens) { + log.warn("Not supported."); + } + + @Override + public long estimateInputTokens() { + return 0; + } + + @Override + public AbstractBody getBody() { + if (body != null) + return body; + try { + return new Body(om + .writerWithDefaultPrettyPrinter() + .writeValueAsString(json).getBytes(UTF_8)); + } catch (JsonProcessingException e) { + log.info("Could not serialize JSON: {}", e.getMessage()); + throw new RuntimeException("Could not serialize JSON: " + e.getMessage()); + } + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java index 9798483ed5..d24e5154c1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/OpenAIProvider.java @@ -36,6 +36,9 @@ public class OpenAIProvider extends AbstractLLMProvider { @Override public LLMRequest getLLMRequest(Exchange exchange) throws IOException { var uri = exchange.getRequest().getUri(); + if (uri.startsWith("/v1/chat/completions")) { + return new OpenAIChatCompletionsRequest(exchange); + } if (uri.startsWith("/v1/responses")) { return new OpenAiLLMResponsesRequest(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java b/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java index 1a5e14ee8e..c2253e333a 100644 --- a/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java +++ b/core/src/main/java/com/predic8/membrane/core/multipart/XOPReconstitutor.java @@ -59,7 +59,7 @@ public XOPReconstitutor() { xmlInputFactory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false); } - public InputStream reconstituteIfNecessary(Message message) throws IOException { + public InputStream reconstituteIfNecessary(Message message) { try { Message reconstitutedMessage = getReconstitutedMessage(message); if (reconstitutedMessage != null) diff --git a/core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java b/core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java index 950896246b..86c8fbf1e3 100644 --- a/core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java +++ b/core/src/test/java/com/predic8/membrane/core/multipart/MultipartUtilTest.java @@ -166,7 +166,6 @@ void explicitBoundaryOverloadProducesSameResult() throws IOException { // Real-world resource: XOP multipart from ReassembleTest // ------------------------------------------------------------------------- - @SuppressWarnings("DataFlowIssue") @Test void xopResourceSplitsIntoTwoParts() throws IOException { byte[] body = IOUtils.toByteArray(getClass().getResourceAsStream("/multipart/embedded-byte-array.txt")); From 96bc067efb2bc3c622574ea26c31df36450a37df Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 5 Jun 2026 08:07:42 +0200 Subject: [PATCH 42/43] refactor: remove specialized request classes and simplify request handling with `BaseLLMRequest` --- .../provider/AbstractLLMProvider.java | 19 +------------------ .../provider/AbstractModelInputRequest.java | 2 +- ...actLLMRequest.java => BaseLLMRequest.java} | 4 ++-- .../provider/openai/AudioRequest.java | 12 ------------ .../provider/openai/FilesRequest.java | 12 ------------ .../provider/openai/ImagesRequest.java | 11 ----------- 6 files changed, 4 insertions(+), 56 deletions(-) rename core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/{AbstractLLMRequest.java => BaseLLMRequest.java} (87%) delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java delete mode 100644 core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java index 6273c2c5db..9b833dea56 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMProvider.java @@ -1,12 +1,7 @@ package com.predic8.membrane.core.interceptor.llmgateway.provider; import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.http.ReadingBodyException; import com.predic8.membrane.core.interceptor.llmgateway.provider.chatcompletions.ChatCompletionsRequest; -import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.AudioRequest; -import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.FilesRequest; -import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.ImagesRequest; -import com.predic8.membrane.core.interceptor.llmgateway.provider.openai.OrganizationRequest; import java.io.IOException; @@ -18,18 +13,6 @@ public LLMRequest getLLMRequest(Exchange exchange) throws IOException { if (uri.startsWith("/v1/chat/completions")) { return new ChatCompletionsRequest(exchange); } - if (uri.startsWith("/v1/files")) { - return new FilesRequest(exchange); - } - if (uri.contains("/v1/images")) { - return new ImagesRequest(exchange); - } - if (uri.contains("/v1/audio")) { - return new AudioRequest(exchange); - } - if (uri.contains("/v1/organization")) { - return new OrganizationRequest(exchange); - } - throw new ReadingBodyException("Unknown request: " + uri); + return new BaseLLMRequest(exchange); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java index 2fb780f214..8551dcbc4a 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java @@ -33,7 +33,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; -public class AbstractModelInputRequest extends AbstractLLMRequest implements ModelInputRequest { +public class AbstractModelInputRequest extends BaseLLMRequest implements ModelInputRequest { private static final Logger log = LoggerFactory.getLogger(AbstractModelInputRequest.class); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/BaseLLMRequest.java similarity index 87% rename from core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java rename to core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/BaseLLMRequest.java index 4c7c8fded3..40b317de09 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractLLMRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/BaseLLMRequest.java @@ -4,11 +4,11 @@ import static com.predic8.membrane.core.http.Header.AUTHORIZATION; -public class AbstractLLMRequest extends AbstractLLMMessage implements LLMRequest { +public class BaseLLMRequest extends AbstractLLMMessage implements LLMRequest { public static final String BEARER_PREFIX = "Bearer"; - protected AbstractLLMRequest(Exchange exchange) { + protected BaseLLMRequest(Exchange exchange) { super(exchange); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java deleted file mode 100644 index df028457a3..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/AudioRequest.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractModelInputRequest; - -import java.io.IOException; - -public class AudioRequest extends AbstractModelInputRequest { - public AudioRequest(Exchange exchange) throws IOException { - super(exchange); - } -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java deleted file mode 100644 index cc56d7b492..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/FilesRequest.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; - -public class FilesRequest extends AbstractLLMRequest { - - public FilesRequest(Exchange exchange) { - super(exchange); - } - -} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java deleted file mode 100644 index 74ac706d04..0000000000 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/openai/ImagesRequest.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.predic8.membrane.core.interceptor.llmgateway.provider.openai; - -import com.predic8.membrane.core.exchange.Exchange; -import com.predic8.membrane.core.interceptor.llmgateway.provider.AbstractLLMRequest; - -public class ImagesRequest extends AbstractLLMRequest { - - public ImagesRequest(Exchange exchange) { - super(exchange); - } -} From 398782a1551342450e99a37da7b08e039a244115 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 9 Jun 2026 16:50:54 +0200 Subject: [PATCH 43/43] fix: correct typos, improve exception handling, and update default token logic - Fixed typos in YAML tutorial and Javadoc comments (e.g., "Antropic" to "Anthropic"). - Replaced `IllegalArgumentException` with `ConfigurationException` for token validation in `DefaultPolicies`. - Updated `getRequestedMaxOutputTokens` to return `-1` instead of `0` for default behavior. - Added exception for unknown event format in `AbstractLLMEvent`. --- .../core/interceptor/llmgateway/AbstractLLMEvent.java | 1 + .../membrane/core/interceptor/llmgateway/DefaultPolicies.java | 4 ++-- .../llmgateway/provider/AbstractModelInputRequest.java | 2 +- .../llmgateway/provider/claude/ClaudeProvider.java | 2 +- .../tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/AbstractLLMEvent.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/AbstractLLMEvent.java index ed9fe0929c..c873d946cf 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/AbstractLLMEvent.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/AbstractLLMEvent.java @@ -48,6 +48,7 @@ public static AbstractLLMEvent create(SSEParser.SSEEvent sse) { var opt = JsonUtil.getJsonObject(sse.data()); if (opt.isEmpty()) { log.info("Unknown event format: {}", sse.data()); + throw new RuntimeException("Unknown event format: " + sse.data()); } var json = opt.get(); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java index 1219fb66b3..bedbe1627c 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/DefaultPolicies.java @@ -101,7 +101,7 @@ public List getModels() { /** * @param models List of models that can be used by the gateway. - * @desciption Restricts the models that can be used by the gateway. + * @description Restricts the models that can be used by the gateway. * @default null (no restriction) */ @MCAttribute @@ -123,7 +123,7 @@ public int getMaxOutputTokens() { @MCAttribute public void setMaxOutputTokens(int maxOutputTokens) { if (maxOutputTokens < 0) { - throw new IllegalArgumentException("maxOutputTokens must be >= 0"); + throw new ConfigurationException("maxOutputTokens must be >= 0"); } this.maxOutputTokens = maxOutputTokens; } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java index 8551dcbc4a..0ba3b4c560 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/AbstractModelInputRequest.java @@ -117,7 +117,7 @@ public String getModel() { @Override public long getRequestedMaxOutputTokens() { - return 0; + return -1; } @Override diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java index decc7048b2..e07822b119 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/llmgateway/provider/claude/ClaudeProvider.java @@ -25,7 +25,7 @@ import java.util.function.Consumer; /** - * @description (Experimental) Anthroic Claude provider configuration + * @description (Experimental) Anthropic Claude provider configuration * Use to configure a LLM gateway to use the anthropic API */ @MCElement( name="claude") diff --git a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml index ddaaaedcf1..ff12b0f9c8 100644 --- a/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml +++ b/distribution/tutorials/ai/llm-gateway/claude/10-Basic-LLM-Gateway.yaml @@ -1,6 +1,6 @@ # yaml-language-server: $schema=https://www.membrane-api.io/v7.2.1.json # -# Tutorial: Basic LLM Gateway (Antropic Claude) +# Tutorial: Basic LLM Gateway (Anthropic Claude) # # Replace <> with your Claude API key. #