|
1 | 1 | use crate::adapters::common; |
2 | | -use crate::adapters::llm::{LLMAdapter, LLMRequest, LLMResponse, ModelConfig, Usage}; |
| 2 | +use crate::adapters::llm::{ |
| 3 | + ChatRequest, ChatResponse, ContentBlock, LLMAdapter, LLMRequest, LLMResponse, ModelConfig, |
| 4 | + StopReason, Usage, |
| 5 | +}; |
3 | 6 | use anyhow::{Context, Result}; |
4 | 7 | use async_trait::async_trait; |
5 | 8 | use reqwest::Client; |
@@ -48,6 +51,59 @@ struct AnthropicUsage { |
48 | 51 | output_tokens: usize, |
49 | 52 | } |
50 | 53 |
|
// === Chat API types (for tool use) ===

/// Request body for Anthropic's Messages API (`POST /messages`),
/// including optional tool definitions for tool use.
#[derive(Serialize)]
struct AnthropicChatRequest {
    model: String,
    messages: Vec<AnthropicChatMessage>,
    max_tokens: usize,
    temperature: f32,
    system: String,
    // Omitted from the JSON entirely when no tools are supplied,
    // rather than being serialized as `"tools": null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AnthropicToolDef>>,
}
| 66 | + |
/// A single conversation turn in the Anthropic wire format; `content` is a
/// list of typed blocks (text / tool_use / tool_result) rather than a string.
#[derive(Serialize, Deserialize, Clone)]
struct AnthropicChatMessage {
    role: String,
    content: Vec<AnthropicContentBlock>,
}
| 72 | + |
/// Wire-format content block. Internally tagged with `"type"` and
/// snake_cased variant names ("text" / "tool_use" / "tool_result") to match
/// Anthropic's schema.
#[derive(Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContentBlock {
    Text {
        text: String,
    },
    /// Model-initiated tool invocation; `input` carries the arguments as raw JSON.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Caller-supplied result for an earlier tool_use block, correlated by id.
    ToolResult {
        tool_use_id: String,
        content: String,
        // Defaults to false on deserialize; omitted on serialize when false.
        #[serde(default, skip_serializing_if = "std::ops::Not::not")]
        is_error: bool,
    },
}
| 91 | + |
/// Tool definition advertised to the model; `input_schema` is expected to be
/// a JSON Schema object describing the tool's arguments.
#[derive(Serialize)]
struct AnthropicToolDef {
    name: String,
    description: String,
    input_schema: serde_json::Value,
}
| 98 | + |
/// Subset of the Messages API response body that this adapter consumes.
// NOTE(review): `stop_reason` is deserialized as a required String here;
// the API documents it as nullable in some contexts (e.g. streaming) —
// confirm non-streaming responses always carry it.
#[derive(Deserialize)]
struct AnthropicChatResponse {
    content: Vec<AnthropicContentBlock>,
    model: String,
    usage: AnthropicUsage,
    stop_reason: String,
}
| 106 | + |
51 | 107 | impl AnthropicAdapter { |
52 | 108 | pub fn new(config: ModelConfig) -> Result<Self> { |
53 | 109 | let base_url = config |
@@ -149,12 +205,135 @@ impl LLMAdapter for AnthropicAdapter { |
    /// Returns the configured model identifier.
    fn model_name(&self) -> &str {
        &self.config.model_name
    }
| 208 | + |
| 209 | + async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> { |
| 210 | + let messages: Vec<AnthropicChatMessage> = request |
| 211 | + .messages |
| 212 | + .iter() |
| 213 | + .map(|m| AnthropicChatMessage { |
| 214 | + role: m.role.to_string(), |
| 215 | + content: m |
| 216 | + .content |
| 217 | + .iter() |
| 218 | + .map(|b| match b { |
| 219 | + ContentBlock::Text { text } => { |
| 220 | + AnthropicContentBlock::Text { text: text.clone() } |
| 221 | + } |
| 222 | + ContentBlock::ToolUse { id, name, input } => { |
| 223 | + AnthropicContentBlock::ToolUse { |
| 224 | + id: id.clone(), |
| 225 | + name: name.clone(), |
| 226 | + input: input.clone(), |
| 227 | + } |
| 228 | + } |
| 229 | + ContentBlock::ToolResult { |
| 230 | + tool_use_id, |
| 231 | + content, |
| 232 | + is_error, |
| 233 | + } => AnthropicContentBlock::ToolResult { |
| 234 | + tool_use_id: tool_use_id.clone(), |
| 235 | + content: content.clone(), |
| 236 | + is_error: *is_error, |
| 237 | + }, |
| 238 | + }) |
| 239 | + .collect(), |
| 240 | + }) |
| 241 | + .collect(); |
| 242 | + |
| 243 | + let tools: Option<Vec<AnthropicToolDef>> = if request.tools.is_empty() { |
| 244 | + None |
| 245 | + } else { |
| 246 | + Some( |
| 247 | + request |
| 248 | + .tools |
| 249 | + .iter() |
| 250 | + .map(|t| AnthropicToolDef { |
| 251 | + name: t.name.clone(), |
| 252 | + description: t.description.clone(), |
| 253 | + input_schema: t.input_schema.clone(), |
| 254 | + }) |
| 255 | + .collect(), |
| 256 | + ) |
| 257 | + }; |
| 258 | + |
| 259 | + let anthropic_request = AnthropicChatRequest { |
| 260 | + model: self.config.model_name.clone(), |
| 261 | + messages, |
| 262 | + max_tokens: request.max_tokens.unwrap_or(self.config.max_tokens), |
| 263 | + temperature: request.temperature.unwrap_or(self.config.temperature), |
| 264 | + system: request.system_prompt, |
| 265 | + tools, |
| 266 | + }; |
| 267 | + |
| 268 | + let url = format!("{}/messages", self.base_url); |
| 269 | + let response = common::send_with_retry_config("Anthropic", &self.retry_config, &mut || { |
| 270 | + self.client |
| 271 | + .post(&url) |
| 272 | + .header("x-api-key", &self.api_key) |
| 273 | + .header("anthropic-version", "2023-06-01") |
| 274 | + .header("Content-Type", "application/json") |
| 275 | + .json(&anthropic_request) |
| 276 | + }) |
| 277 | + .await |
| 278 | + .context("Failed to send chat request to Anthropic")?; |
| 279 | + |
| 280 | + let anthropic_response: AnthropicChatResponse = response |
| 281 | + .json() |
| 282 | + .await |
| 283 | + .context("Failed to parse Anthropic chat response")?; |
| 284 | + |
| 285 | + let content: Vec<ContentBlock> = anthropic_response |
| 286 | + .content |
| 287 | + .into_iter() |
| 288 | + .map(|b| match b { |
| 289 | + AnthropicContentBlock::Text { text } => ContentBlock::Text { text }, |
| 290 | + AnthropicContentBlock::ToolUse { id, name, input } => { |
| 291 | + ContentBlock::ToolUse { id, name, input } |
| 292 | + } |
| 293 | + AnthropicContentBlock::ToolResult { |
| 294 | + tool_use_id, |
| 295 | + content, |
| 296 | + is_error, |
| 297 | + } => ContentBlock::ToolResult { |
| 298 | + tool_use_id, |
| 299 | + content, |
| 300 | + is_error, |
| 301 | + }, |
| 302 | + }) |
| 303 | + .collect(); |
| 304 | + |
| 305 | + let stop_reason = match anthropic_response.stop_reason.as_str() { |
| 306 | + "end_turn" => StopReason::EndTurn, |
| 307 | + "tool_use" => StopReason::ToolUse, |
| 308 | + "max_tokens" => StopReason::MaxTokens, |
| 309 | + _ => StopReason::Other, |
| 310 | + }; |
| 311 | + |
| 312 | + Ok(ChatResponse { |
| 313 | + content, |
| 314 | + model: anthropic_response.model, |
| 315 | + usage: Some(Usage { |
| 316 | + prompt_tokens: anthropic_response.usage.input_tokens, |
| 317 | + completion_tokens: anthropic_response.usage.output_tokens, |
| 318 | + total_tokens: anthropic_response.usage.input_tokens |
| 319 | + + anthropic_response.usage.output_tokens, |
| 320 | + }), |
| 321 | + stop_reason, |
| 322 | + }) |
| 323 | + } |
| 324 | + |
    /// This adapter implements the tool-use `chat` path, so it reports
    /// tool support unconditionally.
    fn supports_tools(&self) -> bool {
        true
    }
152 | 328 | } |
153 | 329 |
|
154 | 330 | #[cfg(test)] |
155 | 331 | mod tests { |
156 | 332 | use super::*; |
157 | | - use crate::adapters::llm::{LLMAdapter, LLMRequest, ModelConfig}; |
| 333 | + use crate::adapters::llm::{ |
| 334 | + ChatMessage, ChatRequest, ChatRole, ContentBlock as CB, LLMAdapter, LLMRequest, |
| 335 | + ModelConfig, StopReason, ToolDefinition, |
| 336 | + }; |
158 | 337 |
|
159 | 338 | fn test_config(base_url: &str) -> ModelConfig { |
160 | 339 | ModelConfig { |
@@ -419,4 +598,124 @@ mod tests { |
419 | 598 |
|
420 | 599 | assert!(result.is_ok()); |
421 | 600 | } |
| 601 | + |
| 602 | + #[test] |
| 603 | + fn test_supports_tools() { |
| 604 | + let config = test_config("http://localhost:8080"); |
| 605 | + let adapter = AnthropicAdapter::new(config).unwrap(); |
| 606 | + assert!(adapter.supports_tools()); |
| 607 | + } |
| 608 | + |
| 609 | + fn make_chat_request() -> ChatRequest { |
| 610 | + ChatRequest { |
| 611 | + system_prompt: "You are a code reviewer.".to_string(), |
| 612 | + messages: vec![ChatMessage { |
| 613 | + role: ChatRole::User, |
| 614 | + content: vec![CB::Text { |
| 615 | + text: "Review this.".to_string(), |
| 616 | + }], |
| 617 | + }], |
| 618 | + tools: vec![ToolDefinition { |
| 619 | + name: "read_file".to_string(), |
| 620 | + description: "Read a file".to_string(), |
| 621 | + input_schema: serde_json::json!({ |
| 622 | + "type": "object", |
| 623 | + "properties": {"file_path": {"type": "string"}}, |
| 624 | + "required": ["file_path"] |
| 625 | + }), |
| 626 | + }], |
| 627 | + temperature: None, |
| 628 | + max_tokens: None, |
| 629 | + } |
| 630 | + } |
| 631 | + |
| 632 | + #[tokio::test] |
| 633 | + async fn test_chat_end_turn() { |
| 634 | + let mut server = mockito::Server::new_async().await; |
| 635 | + let _mock = server |
| 636 | + .mock("POST", "/messages") |
| 637 | + .with_status(200) |
| 638 | + .with_header("content-type", "application/json") |
| 639 | + .with_body( |
| 640 | + r#"{ |
| 641 | + "content": [{"type": "text", "text": "LGTM, no issues found."}], |
| 642 | + "model": "claude-3-5-sonnet-20241022", |
| 643 | + "usage": {"input_tokens": 100, "output_tokens": 20}, |
| 644 | + "stop_reason": "end_turn" |
| 645 | + }"#, |
| 646 | + ) |
| 647 | + .create_async() |
| 648 | + .await; |
| 649 | + |
| 650 | + let adapter = AnthropicAdapter::new(test_config(&server.url())).unwrap(); |
| 651 | + let result = adapter.chat(make_chat_request()).await.unwrap(); |
| 652 | + |
| 653 | + assert_eq!(result.stop_reason, StopReason::EndTurn); |
| 654 | + assert_eq!(result.content.len(), 1); |
| 655 | + match &result.content[0] { |
| 656 | + CB::Text { text } => assert_eq!(text, "LGTM, no issues found."), |
| 657 | + _ => panic!("Expected text block"), |
| 658 | + } |
| 659 | + assert_eq!(result.usage.unwrap().total_tokens, 120); |
| 660 | + } |
| 661 | + |
| 662 | + #[tokio::test] |
| 663 | + async fn test_chat_tool_use_response() { |
| 664 | + let mut server = mockito::Server::new_async().await; |
| 665 | + let _mock = server |
| 666 | + .mock("POST", "/messages") |
| 667 | + .with_status(200) |
| 668 | + .with_header("content-type", "application/json") |
| 669 | + .with_body( |
| 670 | + r#"{ |
| 671 | + "content": [ |
| 672 | + {"type": "text", "text": "Let me check that file."}, |
| 673 | + {"type": "tool_use", "id": "toolu_01", "name": "read_file", "input": {"file_path": "src/main.rs"}} |
| 674 | + ], |
| 675 | + "model": "claude-3-5-sonnet-20241022", |
| 676 | + "usage": {"input_tokens": 100, "output_tokens": 30}, |
| 677 | + "stop_reason": "tool_use" |
| 678 | + }"#, |
| 679 | + ) |
| 680 | + .create_async() |
| 681 | + .await; |
| 682 | + |
| 683 | + let adapter = AnthropicAdapter::new(test_config(&server.url())).unwrap(); |
| 684 | + let result = adapter.chat(make_chat_request()).await.unwrap(); |
| 685 | + |
| 686 | + assert_eq!(result.stop_reason, StopReason::ToolUse); |
| 687 | + assert_eq!(result.content.len(), 2); |
| 688 | + match &result.content[1] { |
| 689 | + CB::ToolUse { id, name, input } => { |
| 690 | + assert_eq!(id, "toolu_01"); |
| 691 | + assert_eq!(name, "read_file"); |
| 692 | + assert_eq!(input["file_path"], "src/main.rs"); |
| 693 | + } |
| 694 | + _ => panic!("Expected ToolUse block"), |
| 695 | + } |
| 696 | + } |
| 697 | + |
| 698 | + #[tokio::test] |
| 699 | + async fn test_chat_max_tokens_stop_reason() { |
| 700 | + let mut server = mockito::Server::new_async().await; |
| 701 | + let _mock = server |
| 702 | + .mock("POST", "/messages") |
| 703 | + .with_status(200) |
| 704 | + .with_header("content-type", "application/json") |
| 705 | + .with_body( |
| 706 | + r#"{ |
| 707 | + "content": [{"type": "text", "text": "Partial response..."}], |
| 708 | + "model": "claude-3-5-sonnet-20241022", |
| 709 | + "usage": {"input_tokens": 100, "output_tokens": 100}, |
| 710 | + "stop_reason": "max_tokens" |
| 711 | + }"#, |
| 712 | + ) |
| 713 | + .create_async() |
| 714 | + .await; |
| 715 | + |
| 716 | + let adapter = AnthropicAdapter::new(test_config(&server.url())).unwrap(); |
| 717 | + let result = adapter.chat(make_chat_request()).await.unwrap(); |
| 718 | + |
| 719 | + assert_eq!(result.stop_reason, StopReason::MaxTokens); |
| 720 | + } |
422 | 721 | } |
0 commit comments