|
1 | 1 | package dev.bluetree242.saaiaddons.aistudio; |
2 | 2 |
|
3 | | -import dev.bluetree242.saaiaddons.aistudio.api.AiStudioChatLanguageModel; |
4 | 3 | import dev.bluetree242.serverassistantai.api.ServerAssistantAIAPI; |
5 | 4 | import dev.bluetree242.serverassistantai.api.config.option.OptionMap; |
6 | 5 | import dev.bluetree242.serverassistantai.api.registry.chatmodel.ChatModelContext; |
7 | 6 | import dev.bluetree242.serverassistantai.api.registry.chatmodel.ChatModelProvider; |
| 7 | +import dev.langchain4j.data.message.AiMessage; |
| 8 | +import dev.langchain4j.data.message.ChatMessage; |
| 9 | +import dev.langchain4j.data.message.SystemMessage; |
| 10 | +import dev.langchain4j.data.message.UserMessage; |
| 11 | +import dev.langchain4j.model.chat.ChatLanguageModel; |
| 12 | +import dev.langchain4j.model.googleai.GeminiHarmBlockThreshold; |
| 13 | +import dev.langchain4j.model.googleai.GeminiHarmCategory; |
| 14 | +import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel; |
| 15 | +import dev.langchain4j.model.output.Response; |
8 | 16 | import lombok.RequiredArgsConstructor; |
9 | 17 | import org.jetbrains.annotations.NotNull; |
10 | 18 | import org.jetbrains.annotations.Nullable; |
11 | 19 |
|
12 | 20 | import java.time.Duration; |
13 | 21 | import java.util.Collections; |
| 22 | +import java.util.HashMap; |
| 23 | +import java.util.List; |
14 | 24 | import java.util.Map; |
| 25 | +import java.util.stream.Collectors; |
15 | 26 |
|
16 | 27 | @RequiredArgsConstructor |
17 | | -public class GoogleAiStudioChatModelProvider implements ChatModelProvider<AiStudioChatLanguageModel> { |
| 28 | +public class GoogleAiStudioChatModelProvider implements ChatModelProvider<GoogleAiStudioChatModelProvider.GoogleAiWrapper> { |
| 29 | + private static final Map<GeminiHarmCategory, GeminiHarmBlockThreshold> safetySettings; |
| 30 | + |
| 31 | + static { |
| 32 | + Map<GeminiHarmCategory, GeminiHarmBlockThreshold> result = new HashMap<>(); |
| 33 | + for (GeminiHarmCategory value : GeminiHarmCategory.values()) { |
| 34 | + result.put(value, GeminiHarmBlockThreshold.BLOCK_NONE); |
| 35 | + } |
| 36 | + safetySettings = Collections.unmodifiableMap(result); |
| 37 | + } |
| 38 | + |
18 | 39 | private final ServerAssistantAIAPI api; |
19 | 40 |
|
20 | 41 | @NotNull |
21 | 42 | @Override |
22 | | - public AiStudioChatLanguageModel provide(@NotNull ChatModelContext context) { |
| 43 | + public GoogleAiWrapper provide(@NotNull ChatModelContext context) { |
23 | 44 | OptionMap options = context.options(); |
24 | 45 | String model = options.getString("model"); |
25 | | - //noinspection unchecked |
26 | 46 | if (model.isBlank()) throw new IllegalStateException("Please set the model for Google AI Studio chat model."); |
27 | | - return AiStudioChatLanguageModel.builder() |
28 | | - .model(model) |
| 47 | + return new GoogleAiWrapper(GoogleAiGeminiChatModel.builder() |
| 48 | + .safetySettings(safetySettings) |
| 49 | + .modelName(model) |
29 | 50 | .timeout(Duration.ofSeconds(options.getLong("timeout"))) |
30 | 51 | .maxOutputTokens(options.getIntegerOrDefault("max_output_tokens", i -> i == 0, null)) |
31 | 52 | .stopSequences(options.getList("stop").getStringList()) |
32 | 53 | .apiKey(api.getCredentialsRegistry().getConfigured(GoogleAiStudioAddon.NAME, GoogleAiStudioCredentialsLoader.class)) |
33 | | - .build(); |
| 54 | + .build()); |
34 | 55 | } |
35 | 56 |
|
36 | 57 | @NotNull |
@@ -58,4 +79,11 @@ public Map<String, Object> export(@NotNull ChatModelContext context) { |
58 | 79 | public String getDisplayName(@Nullable ChatModelContext context) { |
59 | 80 | return "Google AI Studio"; |
60 | 81 | } |
| 82 | + |
| 83 | + public record GoogleAiWrapper(GoogleAiGeminiChatModel model) implements ChatLanguageModel { |
| 84 | + @Override |
| 85 | + public Response<AiMessage> generate(List<ChatMessage> messages) { |
| 86 | + return model.generate(UserMessage.userMessage(messages.stream().filter(m -> m instanceof SystemMessage).map(m -> (SystemMessage) m).map(SystemMessage::text).collect(Collectors.joining("\n\n\n\n")))); |
| 87 | + } |
| 88 | + } |
61 | 89 | } |
0 commit comments