|
5 | 5 | import dev.bluetree242.serverassistantai.api.config.option.OptionMap; |
6 | 6 | import dev.bluetree242.serverassistantai.api.registry.chatmodel.ChatModelContext; |
7 | 7 | import dev.bluetree242.serverassistantai.api.registry.chatmodel.ChatModelProvider; |
8 | | -import dev.langchain4j.data.message.AiMessage; |
9 | | -import dev.langchain4j.data.message.ChatMessage; |
10 | 8 | import dev.langchain4j.data.message.SystemMessage; |
11 | 9 | import dev.langchain4j.data.message.UserMessage; |
12 | 10 | import dev.langchain4j.model.anthropic.AnthropicChatModel; |
13 | | -import dev.langchain4j.model.chat.ChatLanguageModel; |
14 | | -import dev.langchain4j.model.output.Response; |
| 11 | +import dev.langchain4j.model.chat.ChatModel; |
| 12 | +import dev.langchain4j.model.chat.request.ChatRequest; |
| 13 | +import dev.langchain4j.model.chat.response.ChatResponse; |
15 | 14 | import lombok.RequiredArgsConstructor; |
16 | 15 | import org.jetbrains.annotations.NotNull; |
17 | 16 | import org.jetbrains.annotations.Nullable; |
18 | 17 |
|
19 | 18 | import java.time.Duration; |
20 | 19 | import java.util.Collections; |
21 | | -import java.util.List; |
22 | 20 | import java.util.Map; |
23 | 21 | import java.util.stream.Collectors; |
24 | 22 |
|
@@ -73,10 +71,21 @@ public String getDisplayName(@Nullable ChatModelContext context) { |
73 | 71 | return "Anthropic"; |
74 | 72 | } |
75 | 73 |
|
76 | | - public record AnthropicWrapper(AnthropicChatModel model) implements ChatLanguageModel { |
| 74 | + public record AnthropicWrapper(AnthropicChatModel model) implements ChatModel { |
77 | 75 | @Override |
78 | | - public Response<AiMessage> generate(List<ChatMessage> messages) { |
79 | | - return model.generate(UserMessage.userMessage(messages.stream().filter(m -> m instanceof SystemMessage).map(m -> (SystemMessage) m).map(SystemMessage::text).collect(Collectors.joining("\n\n\n\n")))); |
| 76 | + public ChatResponse chat(ChatRequest request) { |
| 77 | + return model.chat(new ChatRequest.Builder() |
| 78 | + .modelName(request.modelName()) |
| 79 | + .maxOutputTokens(request.maxOutputTokens()) |
| 80 | + .temperature(request.temperature()) |
| 81 | + .stopSequences(request.stopSequences()) |
| 82 | + .parameters(request.parameters()) |
| 83 | + .topK(request.topK()) |
| 84 | + .toolChoice(request.toolChoice()) |
| 85 | + .toolSpecifications(request.toolSpecifications()) |
| 86 | + .responseFormat(request.responseFormat()) |
| 87 | + .messages(UserMessage.userMessage(request.messages().stream().filter(m -> m instanceof SystemMessage).map(m -> (SystemMessage) m).map(SystemMessage::text).collect(Collectors.joining("\n\n\n\n")))) |
| 88 | + .build()); |
80 | 89 | } |
81 | 90 | } |
82 | 91 | } |
0 commit comments