-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathAIResponseParser.java
More file actions
82 lines (70 loc) · 3.42 KB
/
AIResponseParser.java
File metadata and controls
82 lines (70 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package com.togetherjava.tjplays.services.chatgpt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Splits long text into pieces that fit within Discord's message length limit. Initially
 * constructed to partition text returned by AI text generation APIs.
 *
 * <p>Fenced code blocks (delimited by {@code ```}) stay fenced: when such a block has to be split,
 * every resulting piece is wrapped in its own fence carrying the block's language tag, so each
 * piece renders as a complete code block on its own.
 */
public final class AIResponseParser {
    private AIResponseParser() {
        throw new UnsupportedOperationException("Utility class, construction not supported");
    }

    private static final Logger logger = LoggerFactory.getLogger(AIResponseParser.class);
    private static final int RESPONSE_LENGTH_LIMIT = 2_000;

    /** Markdown code fence that opens and closes a code block. */
    private static final String CODE_FENCE = "```";

    /**
     * Longest first line of a fenced block that is still treated as a language tag. A longer first
     * line is treated as ordinary code, so the fence overhead can never consume the length budget.
     */
    private static final int MAX_LANGUAGE_TAG_LENGTH = 64;

    /**
     * Parses the response generated by AI. If the response is longer than
     * {@value RESPONSE_LENGTH_LIMIT} characters, breaks it apart into pieces suitable for Discord's
     * API.
     *
     * @param response The response from the AI which we want to send over Discord.
     * @return An array holding either the original response, or the response split into pieces of
     *     at most {@value RESPONSE_LENGTH_LIMIT} characters each; empty pieces are omitted.
     * @throws NullPointerException if {@code response} is {@code null}
     */
    public static String[] parse(String response) {
        if (response.length() <= RESPONSE_LENGTH_LIMIT) {
            return new String[] {response};
        }
        logger.debug("Response to parse:\n{}", response);
        return partitionAiResponse(response);
    }

    /** Splits an over-long response into prose and re-fenced code pieces, preserving order. */
    private static String[] partitionAiResponse(String response) {
        List<String> responseChunks = new ArrayList<>();
        // Segments at odd indices lie between an opening and a closing fence. String.split drops
        // trailing empty strings, so a response ending in ``` yields no trailing empty segment.
        String[] segments = response.split(CODE_FENCE);
        for (int i = 0; i < segments.length; i++) {
            boolean insideCodeBlock = i % 2 != 0;
            if (insideCodeBlock) {
                addCodeBlockChunks(segments[i], responseChunks);
            } else {
                splitToFit(segments[i], RESPONSE_LENGTH_LIMIT, responseChunks);
            }
        }
        // Drop empty pieces, e.g. the empty prose segment before a response that starts with ```.
        responseChunks.removeIf(String::isEmpty);
        return responseChunks.toArray(new String[0]);
    }

    /**
     * Splits the content of one fenced code block and appends each piece, wrapped in its own
     * fence, to {@code out}. Every appended piece is at most {@value RESPONSE_LENGTH_LIMIT}
     * characters long including the fences.
     *
     * @param segment text between an opening and a closing fence; everything before the first
     *     {@code '\n'} is assumed to be the language tag (possibly empty)
     * @param out list receiving the fenced pieces in order
     */
    private static void addCodeBlockChunks(String segment, List<String> out) {
        int firstNewline = segment.indexOf('\n');
        String language = "";
        String code = segment;
        // Searching for '\n' (not System.lineSeparator()) works for both LF and CRLF text; with
        // CRLF the '\r' stays at the end of the tag and the line ending is reproduced intact.
        if (firstNewline >= 0 && firstNewline <= MAX_LANGUAGE_TAG_LENGTH) {
            language = segment.substring(0, firstNewline);
            code = segment.substring(firstNewline);
        }
        String opening = CODE_FENCE + language;
        // Reserve room for the opening fence + tag, a possibly inserted newline, and the closing
        // fence, so the wrapped piece never exceeds the limit.
        int budget = RESPONSE_LENGTH_LIMIT - opening.length() - 1 - CODE_FENCE.length();

        List<String> pieces = new ArrayList<>();
        splitToFit(code, budget, pieces);
        for (String piece : pieces) {
            if (piece.isEmpty()) {
                continue;
            }
            // The code must start on a new line, otherwise Discord would read its first word
            // as the language tag.
            String separator = piece.startsWith("\n") ? "" : "\n";
            out.add(opening + separator + piece + CODE_FENCE);
        }
    }

    /**
     * Appends {@code text} to {@code out}, recursively halved until every piece is at most
     * {@code maxLength} characters long. Pieces are appended in their original order and
     * concatenate back to {@code text}.
     */
    private static void splitToFit(String text, int maxLength, List<String> out) {
        if (text.length() <= maxLength) {
            out.add(text);
            return;
        }
        // 0 < cut < text.length(), so both halves are strictly shorter and recursion terminates.
        int cut = findCut(text);
        splitToFit(text.substring(0, cut), maxLength, out);
        splitToFit(text.substring(cut), maxLength, out);
    }

    /**
     * Chooses where to split {@code text} (length at least 2): at the last newline at or before
     * the midpoint, or at the midpoint itself when no such newline exists past index 0.
     */
    private static int findCut(String text) {
        int mid = text.length() / 2;
        int newline = text.lastIndexOf('\n', mid);
        if (newline > 0) {
            // The newline starts the second piece, as in the original splitting scheme.
            return newline;
        }
        // Hard split; step back one char rather than separating a UTF-16 surrogate pair.
        if (mid > 1 && Character.isHighSurrogate(text.charAt(mid - 1))) {
            return mid - 1;
        }
        return mid;
    }
}