Skip to content

Commit 7eab143

Browse files
committed
use proper mustache library to render prompt templates
1 parent 8edc0a1 commit 7eab143

6 files changed

Lines changed: 336 additions & 109 deletions

File tree

build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ dependencies {
6666
implementation 'org.apache.commons:commons-lang3:3.14.0'
6767
implementation 'com.google.code.findbugs:jsr305:3.0.2' // for @Nullable annotations
6868

69+
implementation "com.github.spullara.mustache.java:compiler:0.9.14"
70+
6971
testImplementation "org.slf4j:slf4j-simple:${slf4jVersion}"
7072
testImplementation "io.opentelemetry:opentelemetry-sdk-testing:${otelVersion}"
7173
testImplementation "org.junit.jupiter:junit-jupiter:${junitVersion}"

src/main/java/dev/braintrust/instrumentation/openai/BraintrustOpenAI.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public static OpenAIClient wrapOpenAI(OpenTelemetry openTelemetry, OpenAIClient
3131

3232
@SneakyThrows
3333
public static ChatCompletionCreateParams buildChatCompletionsPrompt(
34-
BraintrustPrompt prompt, Map<String, String> parameters) {
34+
BraintrustPrompt prompt, Map<String, Object> parameters) {
3535
var promptMap = new HashMap<>(prompt.getOptions());
3636
promptMap.put("messages", prompt.renderMessages(parameters));
3737
var promptJson = ObjectMappers.jsonMapper().writeValueAsString(promptMap);

src/main/java/dev/braintrust/prompt/BraintrustPrompt.java

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
package dev.braintrust.prompt;
22

3+
import com.github.mustachejava.DefaultMustacheFactory;
4+
import com.github.mustachejava.Mustache;
5+
import com.github.mustachejava.MustacheException;
36
import dev.braintrust.api.BraintrustApiClient;
7+
import java.io.StringReader;
8+
import java.io.StringWriter;
49
import java.util.ArrayList;
510
import java.util.HashMap;
6-
import java.util.HashSet;
711
import java.util.List;
812
import java.util.Map;
9-
import java.util.Set;
10-
import java.util.regex.Matcher;
11-
import java.util.regex.Pattern;
1213

1314
public class BraintrustPrompt {
14-
private static final Pattern MUSTACHE_PATTERN = Pattern.compile("\\{\\{([^}]+)\\}\\}");
15-
1615
private final BraintrustApiClient.Prompt apiPrompt;
1716
private final Map<String, String> defaults;
1817

@@ -25,7 +24,7 @@ public BraintrustPrompt(BraintrustApiClient.Prompt apiPrompt, Map<String, String
2524
this.defaults = defaults;
2625
}
2726

28-
public List<Map<String, Object>> renderMessages(Map<String, String> parameters) {
27+
public List<Map<String, Object>> renderMessages(Map<String, Object> parameters) {
2928
// get promptData->prompt->messages
3029
Map<String, Object> promptData = (Map<String, Object>) apiPrompt.promptData().prompt();
3130
List<Map<String, Object>> messages = (List<Map<String, Object>>) promptData.get("messages");
@@ -34,28 +33,20 @@ public List<Map<String, Object>> renderMessages(Map<String, String> parameters)
3433
throw new RuntimeException("No messages found in prompt data");
3534
}
3635

37-
Set<String> usedParameters = new HashSet<>();
3836
List<Map<String, Object>> renderedMessages = new ArrayList<>();
3937

4038
for (Map<String, Object> message : messages) {
4139
Map<String, Object> renderedMessage = new HashMap<>(message);
4240
String content = (String) message.get("content");
4341

4442
if (content != null) {
45-
String renderedContent = renderTemplate(content, parameters, usedParameters);
43+
String renderedContent = renderTemplate(content, parameters);
4644
renderedMessage.put("content", renderedContent);
4745
}
4846

4947
renderedMessages.add(renderedMessage);
5048
}
5149

52-
// Check if all parameters were used
53-
Set<String> unusedParameters = new HashSet<>(parameters.keySet());
54-
unusedParameters.removeAll(usedParameters);
55-
if (!unusedParameters.isEmpty()) {
56-
throw new RuntimeException("Unused parameters: " + unusedParameters);
57-
}
58-
5950
return renderedMessages;
6051
}
6152

@@ -92,24 +83,22 @@ public Map<String, Object> getOptions() {
9283
return result;
9384
}
9485

95-
private String renderTemplate(
96-
String template, Map<String, String> parameters, Set<String> usedParameters) {
97-
Matcher matcher = MUSTACHE_PATTERN.matcher(template);
98-
StringBuffer result = new StringBuffer();
99-
100-
while (matcher.find()) {
101-
String paramName = matcher.group(1);
102-
String paramValue = parameters.get(paramName);
103-
104-
if (paramValue == null) {
105-
throw new RuntimeException("Missing parameter: " + paramName);
86+
private String renderTemplate(String template, Map<String, Object> parameters) {
87+
try {
88+
DefaultMustacheFactory factory = new DefaultMustacheFactory();
89+
Mustache mustache = factory.compile(new StringReader(template), "template");
90+
StringWriter writer = new StringWriter();
91+
mustache.execute(writer, parameters);
92+
writer.flush();
93+
return writer.toString();
94+
} catch (MustacheException e) {
95+
// If the template is malformed, just return it as-is
96+
return template;
97+
} catch (Exception e) {
98+
if (e instanceof RuntimeException) {
99+
throw (RuntimeException) e;
106100
}
107-
108-
usedParameters.add(paramName);
109-
matcher.appendReplacement(result, Matcher.quoteReplacement(paramValue));
101+
throw new RuntimeException("Failed to render template", e);
110102
}
111-
112-
matcher.appendTail(result);
113-
return result.toString();
114103
}
115104
}

src/test/java/dev/braintrust/instrumentation/openai/BraintrustOpenAITest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ void testBuildChatCompletionsPrompt() {
498498

499499
BraintrustPrompt prompt = new BraintrustPrompt(promptObject);
500500

501-
Map<String, String> parameters = Map.of("name", "Alice");
501+
Map<String, Object> parameters = Map.of("name", "Alice");
502502
ChatCompletionCreateParams renderedParams =
503503
BraintrustOpenAI.buildChatCompletionsPrompt(prompt, parameters);
504504

src/test/java/dev/braintrust/prompt/BraintrustPromptLoaderTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void testLoadPromptBySlug() {
4747
assertNotNull(prompt);
4848

4949
// Test rendering
50-
Map<String, String> parameters = Map.of("name", "Bob");
50+
Map<String, Object> parameters = Map.of("name", "Bob");
5151
List<Map<String, Object>> renderedMessages = prompt.renderMessages(parameters);
5252

5353
assertEquals(2, renderedMessages.size());

0 commit comments

Comments
 (0)