Skip to content

Commit c6aadd9

Browse files
authored
feat: add tools to AI (#1489)
* feat: update chatgpt to have an opt-in thinking mode. This allows for the AI to have tool usage. The tool included with this PR is web search.
1 parent 021cad3 commit c6aadd9

19 files changed

Lines changed: 1103 additions & 28 deletions

.gitattributes

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
# This will do normalization to LF on index (staging area)
2-
* text=auto
1+
# Normalize to LF in the index AND in the working tree on all platforms.
2+
# Spotless enforces LF, so without eol=lf, Windows/WSL checkouts get CRLF and
3+
# every spotlessApply rewrites every file.
4+
* text=auto eol=lf
35

46
# Explicit for linux files
57
*.sh text eol=lf

application/config.json.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@
162162
"logInfoChannelWebhook": "<put_your_webhook_here>",
163163
"logErrorChannelWebhook": "<put_your_webhook_here>",
164164
"openaiApiKey": "<check pins in #tjbot_discussion for the key>",
165+
"tavilyApiKey": "<create an account on https://www.tavily.com/ to get an API key for free>",
165166
"sourceCodeBaseUrl": "https://github.com/Together-Java/TJ-Bot/blob/master/application/src/main/java/",
166167
"jshell": {
167168
"baseUrl": "<put_jshell_rest_api_url_here>",

application/src/main/java/org/togetherjava/tjbot/config/Config.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public final class Config {
5252
private final QuoteBoardConfig quoteBoardConfig;
5353
private final TopHelpersConfig topHelpers;
5454
private final DynamicVoiceChatConfig dynamicVoiceChatConfig;
55+
private final String tavilyApiKey;
5556

5657
@SuppressWarnings("ConstructorWithTooManyParameters")
5758
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
@@ -111,7 +112,8 @@ private Config(@JsonProperty(value = "token", required = true) String token,
111112
required = true) RoleApplicationSystemConfig roleApplicationSystemConfig,
112113
@JsonProperty(value = "topHelpers", required = true) TopHelpersConfig topHelpers,
113114
@JsonProperty(value = "dynamicVoiceChatConfig",
114-
required = true) DynamicVoiceChatConfig dynamicVoiceChatConfig) {
115+
required = true) DynamicVoiceChatConfig dynamicVoiceChatConfig,
116+
@JsonProperty(value = "tavilyApiKey", required = true) String tavilyApiKey) {
115117
this.token = Objects.requireNonNull(token);
116118
this.githubApiKey = Objects.requireNonNull(githubApiKey);
117119
this.databasePath = Objects.requireNonNull(databasePath);
@@ -150,6 +152,7 @@ private Config(@JsonProperty(value = "token", required = true) String token,
150152
this.roleApplicationSystemConfig = roleApplicationSystemConfig;
151153
this.topHelpers = Objects.requireNonNull(topHelpers);
152154
this.dynamicVoiceChatConfig = Objects.requireNonNull(dynamicVoiceChatConfig);
155+
this.tavilyApiKey = Objects.requireNonNull(tavilyApiKey);
153156
}
154157

155158
/**
@@ -499,4 +502,16 @@ public TopHelpersConfig getTopHelpers() {
499502
public DynamicVoiceChatConfig getDynamicVoiceChatConfig() {
500503
return dynamicVoiceChatConfig;
501504
}
505+
506+
/**
507+
* Gets the API key for Tavily (<a href="https://www.tavily.com">tavily.com</a>), a
508+
* search engine API tailored for LLMs. It is used by the ChatGPT command to power the AI-driven
509+
* web search tool, allowing the assistant to fetch up-to-date information from the web when
510+
* answering user questions.
511+
*
512+
* @return the Tavily API key
513+
*/
514+
public String getTavilyApiKey() {
515+
return tavilyApiKey;
516+
}
502517
}

application/src/main/java/org/togetherjava/tjbot/features/Features.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import org.togetherjava.tjbot.features.bookmarks.LeftoverBookmarksListener;
2020
import org.togetherjava.tjbot.features.chatgpt.ChatGptCommand;
2121
import org.togetherjava.tjbot.features.chatgpt.ChatGptService;
22+
import org.togetherjava.tjbot.features.chatgpt.tools.web.FetchUrlTool;
23+
import org.togetherjava.tjbot.features.chatgpt.tools.web.WebSearchTool;
2224
import org.togetherjava.tjbot.features.code.CodeMessageAutoDetection;
2325
import org.togetherjava.tjbot.features.code.CodeMessageHandler;
2426
import org.togetherjava.tjbot.features.code.CodeMessageManualDetection;
@@ -88,6 +90,7 @@
8890

8991
import java.util.ArrayList;
9092
import java.util.Collection;
93+
import java.util.List;
9194

9295
/**
9396
* Utility class that offers all features that should be registered by the system, such as commands.
@@ -218,7 +221,10 @@ public static Collection<Feature> createFeatures(JDA jda, Database database, Con
218221
features.add(new HelpThreadCommand(config, helpSystemHelper, metrics));
219222
features.add(new ReportCommand(config));
220223
features.add(new BookmarksCommand(bookmarksSystem));
221-
features.add(new ChatGptCommand(chatGptService, helpSystemHelper));
224+
225+
features.add(new ChatGptCommand(chatGptService, helpSystemHelper,
226+
List.of(new WebSearchTool(config.getTavilyApiKey()), new FetchUrlTool())));
227+
222228
features.add(new JShellCommand(jshellEval));
223229
features.add(new MessageCommand());
224230
features.add(new RewriteCommand(chatGptService));

application/src/main/java/org/togetherjava/tjbot/features/chatgpt/ChatGptCommand.java

Lines changed: 101 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,33 @@
22

33
import com.github.benmanes.caffeine.cache.Cache;
44
import com.github.benmanes.caffeine.cache.Caffeine;
5+
import net.dv8tion.jda.api.EmbedBuilder;
56
import net.dv8tion.jda.api.entities.MessageEmbed;
67
import net.dv8tion.jda.api.entities.SelfUser;
8+
import net.dv8tion.jda.api.entities.channel.unions.MessageChannelUnion;
79
import net.dv8tion.jda.api.events.interaction.ModalInteractionEvent;
810
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
11+
import net.dv8tion.jda.api.interactions.InteractionHook;
12+
import net.dv8tion.jda.api.interactions.commands.OptionMapping;
13+
import net.dv8tion.jda.api.interactions.commands.OptionType;
914
import net.dv8tion.jda.api.interactions.components.text.TextInput;
1015
import net.dv8tion.jda.api.interactions.components.text.TextInputStyle;
1116
import net.dv8tion.jda.api.interactions.modals.Modal;
1217

1318
import org.togetherjava.tjbot.features.CommandVisibility;
1419
import org.togetherjava.tjbot.features.SlashCommandAdapter;
20+
import org.togetherjava.tjbot.features.chatgpt.tools.AiTool;
1521
import org.togetherjava.tjbot.features.help.HelpSystemHelper;
1622

1723
import java.time.Duration;
1824
import java.time.Instant;
1925
import java.time.temporal.ChronoUnit;
2026
import java.util.List;
21-
import java.util.Optional;
27+
import java.util.Locale;
28+
import java.util.Objects;
29+
import java.util.concurrent.CompletableFuture;
30+
import java.util.concurrent.Executor;
31+
import java.util.concurrent.Executors;
2232

2333
/**
2434
* The implemented command is {@code /chatgpt}, which allows users to ask ChatGPT a question, upon
@@ -27,12 +37,31 @@
2737
public final class ChatGptCommand extends SlashCommandAdapter {
2838
private static final ChatGptModel CHAT_GPT_MODEL = ChatGptModel.HIGH_QUALITY;
2939
public static final String COMMAND_NAME = "chatgpt";
40+
private static final String THINKING_OPTION = "enable_thinking";
3041
private static final String QUESTION_INPUT = "question";
3142
private static final int MAX_MESSAGE_INPUT_LENGTH = 200;
3243
private static final int MIN_MESSAGE_INPUT_LENGTH = 4;
3344
private static final Duration COMMAND_COOLDOWN = Duration.of(10, ChronoUnit.SECONDS);
45+
private static final String ERROR_RESPONSE = """
46+
An error has occurred while trying to communicate with ChatGPT.
47+
Please try again later.
48+
""";
49+
private static final String SYSTEM_PROMPT =
50+
"""
51+
You are a helpful assistant answering questions in a Discord server.
52+
Keep responses concise (no more than 280 words) and use markdown when helpful.
53+
When the user's question depends on live or external information, prefer calling \
54+
the `web_search` tool for current facts and the `fetch_url` tool to read a specific page.
55+
For code review questions, refer to the supplied code rather than rewriting it.""";
56+
3457
private final ChatGptService chatGptService;
3558
private final HelpSystemHelper helper;
59+
private final List<AiTool<?>> tools;
60+
private final Executor worker = Executors.newCachedThreadPool(runnable -> {
61+
Thread thread = new Thread(runnable, "chatgpt-worker");
62+
thread.setDaemon(true);
63+
return thread;
64+
});
3665

3766
private final Cache<String, Instant> userIdToAskedAtCache =
3867
Caffeine.newBuilder().maximumSize(1_000).expireAfterWrite(COMMAND_COOLDOWN).build();
@@ -42,17 +71,24 @@ public final class ChatGptCommand extends SlashCommandAdapter {
4271
*
4372
* @param chatGptService ChatGptService - Needed to make calls to ChatGPT API
4473
* @param helper HelpSystemHelper - Needed to generate response embed for prompt
74+
* @param tools tools the model may invoke while answering; pass an empty list to disable
4575
*/
46-
public ChatGptCommand(ChatGptService chatGptService, HelpSystemHelper helper) {
76+
public ChatGptCommand(ChatGptService chatGptService, HelpSystemHelper helper,
77+
List<AiTool<?>> tools) {
4778
super(COMMAND_NAME, "Ask the ChatGPT AI a question!", CommandVisibility.GUILD);
4879

4980
this.chatGptService = chatGptService;
5081
this.helper = helper;
82+
this.tools = tools;
83+
84+
getData().addOption(OptionType.BOOLEAN, THINKING_OPTION,
85+
"let the model use web tools (search/fetch) to answer", false);
5186
}
5287

5388
@Override
5489
public void onSlashCommand(SlashCommandInteractionEvent event) {
55-
Instant previousAskTime = userIdToAskedAtCache.getIfPresent(event.getMember().getId());
90+
Instant previousAskTime = userIdToAskedAtCache
91+
.getIfPresent(Objects.requireNonNull(event.getMember()).getId());
5692
if (previousAskTime != null) {
5793
long timeRemainingUntilNextAsk =
5894
COMMAND_COOLDOWN.minus(Duration.between(previousAskTime, Instant.now()))
@@ -66,40 +102,85 @@ public void onSlashCommand(SlashCommandInteractionEvent event) {
66102
return;
67103
}
68104

105+
OptionMapping thinkingOption = event.getOption(THINKING_OPTION);
106+
boolean thinkingEnabled = thinkingOption != null && thinkingOption.getAsBoolean();
107+
69108
TextInput body = TextInput
70109
.create(QUESTION_INPUT, "Ask ChatGPT a question or get help with code",
71110
TextInputStyle.PARAGRAPH)
72111
.setPlaceholder("Put your question for ChatGPT here")
73112
.setRequiredRange(MIN_MESSAGE_INPUT_LENGTH, MAX_MESSAGE_INPUT_LENGTH)
74113
.build();
75114

76-
Modal modal = Modal.create(generateComponentId(), "ChatGPT").addActionRow(body).build();
115+
Modal modal =
116+
Modal.create(generateComponentId(Boolean.toString(thinkingEnabled)), "ChatGPT")
117+
.addActionRow(body)
118+
.build();
77119
event.replyModal(modal).queue();
78120
}
79121

80122
@Override
81123
public void onModalSubmitted(ModalInteractionEvent event, List<String> args) {
82124
event.deferReply().queue();
83125

84-
String question = event.getValue(QUESTION_INPUT).getAsString();
85-
86-
Optional<String> chatgptResponse = chatGptService.ask(question,
87-
"You may use markdown syntax for the response", CHAT_GPT_MODEL);
88-
if (chatgptResponse.isPresent()) {
89-
userIdToAskedAtCache.put(event.getMember().getId(), Instant.now());
90-
}
91-
92-
String errorResponse = """
93-
An error has occurred while trying to communicate with ChatGPT.
94-
Please try again later.
95-
""";
96-
97-
String response = chatgptResponse.orElse(errorResponse);
126+
String question = Objects.requireNonNull(event.getValue(QUESTION_INPUT)).getAsString();
98127
SelfUser selfUser = event.getJDA().getSelfUser();
128+
InteractionHook hook = event.getHook();
129+
MessageChannelUnion channel = event.getChannel();
130+
String userId = Objects.requireNonNull(event.getMember()).getId();
131+
boolean thinkingEnabled = !args.isEmpty() && Boolean.parseBoolean(args.getFirst());
132+
List<AiTool<?>> activeTools = thinkingEnabled ? tools : List.<AiTool<?>>of();
133+
134+
ChatGptProgressEmbed progress = new ChatGptProgressEmbed(hook, selfUser, question);
135+
hook.editOriginalEmbeds(progress.initialEmbed()).queue();
136+
137+
Instant startedAt = Instant.now();
138+
CompletableFuture
139+
.supplyAsync(() -> chatGptService.askWithTools(question, CHAT_GPT_MODEL, activeTools,
140+
SYSTEM_PROMPT, progress), worker)
141+
.whenComplete((result, throwable) -> {
142+
String response;
143+
if (throwable == null && result.isPresent()) {
144+
response = result.get();
145+
userIdToAskedAtCache.put(userId, Instant.now());
146+
} else {
147+
response = ERROR_RESPONSE;
148+
}
149+
finishResponse(hook, channel, selfUser, question, response,
150+
Duration.between(startedAt, Instant.now()));
151+
});
152+
}
99153

100-
MessageEmbed responseEmbed =
154+
/**
 * Publishes the final answer and cleans up the interaction reply.
 * <p>
 * The answer is posted as a regular message in the given channel and, once that succeeded,
 * the original interaction reply is deleted. If posting to the channel fails, the answer is
 * written into the interaction reply instead, so that the user is not left without any
 * response.
 *
 * @param hook hook of the interaction reply that is deleted (or reused as fallback)
 * @param channel channel the answer is posted to
 * @param selfUser the bot's own user, forwarded to the embed generation
 * @param question the question the user asked
 * @param response the text to display, either the model's answer or the error text
 * @param elapsed how long answering took, shown in the embed footer
 */
private void finishResponse(InteractionHook hook, MessageChannelUnion channel,
        SelfUser selfUser, String question, String response, Duration elapsed) {
    MessageEmbed baseEmbed =
            helper.generateGptResponseEmbed(response, selfUser, question, CHAT_GPT_MODEL);
    MessageEmbed finalEmbed = withTimingFooter(baseEmbed, elapsed);

    channel.sendMessageEmbeds(finalEmbed).queue(_ -> hook.deleteOriginal().queue(null, _ -> {
        // Best effort: the answer is already visible, a leftover reply is only cosmetic.
    }), _ -> hook.editOriginalEmbeds(finalEmbed).queue(null, _ -> {
        // Best effort: neither the channel nor the interaction reply is reachable anymore.
    }));
}
102164

103-
event.getHook().sendMessageEmbeds(responseEmbed).queue();
165+
private static MessageEmbed withTimingFooter(MessageEmbed embed, Duration elapsed) {
166+
String existing = embed.getFooter() == null ? "" : embed.getFooter().getText();
167+
String suffix = "took %s".formatted(formatDuration(elapsed));
168+
String footer = existing == null || existing.isBlank() ? suffix
169+
: "%s · %s".formatted(existing, suffix);
170+
return new EmbedBuilder(embed).setFooter(footer).build();
171+
}
172+
173+
private static String formatDuration(Duration elapsed) {
174+
long totalMs = Math.max(0, elapsed.toMillis());
175+
if (totalMs < 1_000) {
176+
return totalMs + "ms";
177+
}
178+
long totalSeconds = totalMs / 1_000;
179+
if (totalSeconds < 60) {
180+
return String.format(Locale.ROOT, "%.1fs", totalMs / 1000.0);
181+
}
182+
long minutes = totalSeconds / 60;
183+
long seconds = totalSeconds % 60;
184+
return "%dm %ds".formatted(minutes, seconds);
104185
}
105186
}

0 commit comments

Comments
 (0)