|
1 | | -use anyhow::Result; |
2 | | -use reqwest::Client; |
3 | | -use serde_json::Value; |
4 | | - |
5 | | -pub(in super::super) async fn test_model_inference( |
6 | | - client: &Client, |
7 | | - base_url: &str, |
8 | | - model_name: &str, |
9 | | - endpoint_type: &str, |
10 | | -) -> Result<String> { |
11 | | - let system_msg = "You are a code reviewer. Respond with a single JSON object."; |
12 | | - let user_msg = |
13 | | - "Review this code change:\n+fn add(a: i32, b: i32) -> i32 { a + b }\nRespond with: {\"ok\": true}"; |
14 | | - |
15 | | - let messages = serde_json::json!([ |
16 | | - {"role": "system", "content": system_msg}, |
17 | | - {"role": "user", "content": user_msg} |
18 | | - ]); |
19 | | - |
20 | | - if endpoint_type == "ollama" { |
21 | | - let url = format!("{}/api/chat", base_url); |
22 | | - let body = serde_json::json!({ |
23 | | - "model": model_name, |
24 | | - "messages": messages, |
25 | | - "stream": false, |
26 | | - "options": {"num_predict": 50} |
27 | | - }); |
28 | | - |
29 | | - let resp = client |
30 | | - .post(&url) |
31 | | - .json(&body) |
32 | | - .send() |
33 | | - .await |
34 | | - .map_err(|e| anyhow::anyhow!("Request failed: {}", e))?; |
35 | | - |
36 | | - if !resp.status().is_success() { |
37 | | - let status = resp.status(); |
38 | | - let body = resp.text().await.unwrap_or_default(); |
39 | | - anyhow::bail!("HTTP {} - {}", status, body); |
40 | | - } |
41 | | - |
42 | | - let text = resp.text().await?; |
43 | | - parse_ollama_response_content(&text) |
44 | | - } else { |
45 | | - let url = format!("{}/v1/chat/completions", base_url); |
46 | | - let body = serde_json::json!({ |
47 | | - "model": model_name, |
48 | | - "messages": messages, |
49 | | - "max_tokens": 50, |
50 | | - "temperature": 0.1 |
51 | | - }); |
52 | | - |
53 | | - let resp = client |
54 | | - .post(&url) |
55 | | - .json(&body) |
56 | | - .send() |
57 | | - .await |
58 | | - .map_err(|e| anyhow::anyhow!("Request failed: {}", e))?; |
59 | | - |
60 | | - if !resp.status().is_success() { |
61 | | - let status = resp.status(); |
62 | | - let body = resp.text().await.unwrap_or_default(); |
63 | | - anyhow::bail!("HTTP {} - {}", status, body); |
64 | | - } |
65 | | - |
66 | | - let text = resp.text().await?; |
67 | | - parse_openai_response_content(&text) |
68 | | - } |
69 | | -} |
70 | | - |
71 | | -pub(in super::super) fn estimate_tokens(text: &str) -> usize { |
72 | | - (text.len() / 4).max(1) |
73 | | -} |
74 | | - |
75 | | -fn parse_ollama_response_content(text: &str) -> Result<String> { |
76 | | - let value: Value = serde_json::from_str(text)?; |
77 | | - Ok(value |
78 | | - .get("message") |
79 | | - .and_then(|message| message.get("content")) |
80 | | - .and_then(|content| content.as_str()) |
81 | | - .unwrap_or("") |
82 | | - .to_string()) |
83 | | -} |
84 | | - |
85 | | -fn parse_openai_response_content(text: &str) -> Result<String> { |
86 | | - let value: Value = serde_json::from_str(text)?; |
87 | | - Ok(value |
88 | | - .get("choices") |
89 | | - .and_then(|choices| choices.as_array()) |
90 | | - .and_then(|choices| choices.first()) |
91 | | - .and_then(|choice| choice.get("message")) |
92 | | - .and_then(|message| message.get("content")) |
93 | | - .and_then(|content| content.as_str()) |
94 | | - .unwrap_or("") |
95 | | - .to_string()) |
96 | | -} |
97 | | - |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        // One token minimum, then roughly one per four bytes.
        let cases = [("", 1), ("abcd", 1), ("abcdefgh", 2), ("a]", 1)];
        for (input, expected) in cases {
            assert_eq!(estimate_tokens(input), expected);
        }
    }

    #[test]
    fn test_estimate_tokens_longer_text() {
        let text = "This is a longer response with several words in it for testing.";
        let tokens = estimate_tokens(text);
        // Sanity bounds rather than an exact count.
        assert!((11..30).contains(&tokens));
    }

    #[test]
    fn test_test_model_inference_ollama_parse() {
        // Ollama nests the reply under message.content.
        let json = r#"{"message":{"role":"assistant","content":"{\"ok\": true}"}}"#;
        assert_eq!(
            parse_ollama_response_content(json).unwrap(),
            "{\"ok\": true}"
        );
    }

    #[test]
    fn test_test_model_inference_openai_parse() {
        // OpenAI nests the reply under choices[0].message.content.
        let json = r#"{"choices":[{"message":{"content":"{\"ok\": true}"}}]}"#;
        assert_eq!(
            parse_openai_response_content(json).unwrap(),
            "{\"ok\": true}"
        );
    }

    #[test]
    fn test_test_model_inference_empty_choices() {
        // No choices at all degrades to an empty string, not an error.
        let json = r#"{"choices":[]}"#;
        assert_eq!(parse_openai_response_content(json).unwrap(), "");
    }
}
// The inference implementation is split into files under `inference/`; the
// explicit `#[path]` attributes point each submodule at its source file.
#[path = "inference/request.rs"]
mod request;
#[path = "inference/response.rs"]
mod response;
#[path = "inference/run.rs"]
mod run;

// Re-export the public entry points at their original visibility so existing
// callers (two module levels up) keep the same paths after the split.
pub(in super::super) use run::{estimate_tokens, test_model_inference};
0 commit comments