diff --git a/crates/tabby-common/src/api/code.rs b/crates/tabby-common/src/api/code.rs index dfdc9b1a00d4..abee9d44e40b 100644 --- a/crates/tabby-common/src/api/code.rs +++ b/crates/tabby-common/src/api/code.rs @@ -115,6 +115,15 @@ pub trait CodeSearch: Send + Sync { ) -> Result; } +#[async_trait] +pub trait WarpGrepSearch: Send + Sync { + async fn search( + &self, + repo_dir: &std::path::Path, + query: &str, + ) -> Result; +} + /// Normalize the path form different platform to unix style path pub fn normalize_to_unix_path(filepath: &str) -> String { filepath.replace('\\', "/") diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index 1641473ce73a..4b2c0b2a948d 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -218,6 +218,7 @@ pub async fn main(config: &Config, args: &ServeArgs) { completion_stream, docsearch, |x| Box::new(services::structured_doc::create_serper(x)), + |x| Box::new(services::warpgrep::create(x)), ) .await; api = new_api; diff --git a/crates/tabby/src/services/mod.rs b/crates/tabby/src/services/mod.rs index 2d572751a124..8aee82c251f1 100644 --- a/crates/tabby/src/services/mod.rs +++ b/crates/tabby/src/services/mod.rs @@ -6,3 +6,4 @@ pub mod health; pub mod model; pub mod structured_doc; pub mod tantivy; +pub mod warpgrep; diff --git a/crates/tabby/src/services/warpgrep/mod.rs b/crates/tabby/src/services/warpgrep/mod.rs new file mode 100644 index 000000000000..37f0e046c100 --- /dev/null +++ b/crates/tabby/src/services/warpgrep/mod.rs @@ -0,0 +1,335 @@ +mod tool_executor; + +use std::fs; +use std::path::Path; + +use anyhow::Result; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use tabby_common::api::code::{ + CodeSearchDocument, CodeSearchError, CodeSearchHit, CodeSearchResponse, CodeSearchScores, + WarpGrepSearch, +}; +use tracing::{debug, warn}; + +use self::tool_executor::{ + execute_list_directory, execute_read, execute_ripgrep, format_tool_responses, + 
parse_finish_output, parse_tool_calls, +}; + +const DEFAULT_ENDPOINT: &str = "https://api.morphllm.com/v1/chat/completions"; +const DEFAULT_MODEL: &str = "morph-warp-grep-v2"; +const MAX_TURNS: usize = 4; + +pub struct WarpGrepService { + client: Client, + api_endpoint: String, + model: String, +} + +pub fn create(api_key: &str) -> impl WarpGrepSearch { + let endpoint = + std::env::var("MORPH_API_ENDPOINT").unwrap_or_else(|_| DEFAULT_ENDPOINT.to_string()); + let model = std::env::var("MORPH_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.to_string()); + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}")) + .expect("Invalid API key format"), + ); + + let client = Client::builder() + .default_headers(headers) + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("Failed to build HTTP client"); + + WarpGrepService { + client, + api_endpoint: endpoint, + model, + } +} + +#[derive(Serialize)] +struct ChatRequest { + model: String, + messages: Vec, + temperature: f32, + max_tokens: u32, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +struct Message { + role: String, + content: String, +} + +#[derive(Deserialize)] +struct ChatResponse { + choices: Vec, +} + +#[derive(Deserialize)] +struct Choice { + message: Message, +} + +impl WarpGrepService { + async fn search_impl(&self, repo_dir: &Path, query: &str) -> Result { + let file_tree = build_file_tree(repo_dir)?; + let initial_msg = format!( + "{file_tree}\n{query}" + ); + + let mut messages = vec![Message { + role: "user".to_string(), + content: initial_msg, + }]; + + for turn in 0..MAX_TURNS { + let request = ChatRequest { + model: self.model.clone(), + messages: messages.clone(), + temperature: 0.0, + max_tokens: 2048, + }; + + let resp = self + .client + .post(&self.api_endpoint) + .json(&request) + .send() + .await?; + + if !resp.status().is_success() { + let status = 
resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("WarpGrep API error {status}: {body}"); + } + + let chat_resp: ChatResponse = resp.json().await?; + let assistant_msg = chat_resp + .choices + .into_iter() + .next() + .ok_or_else(|| anyhow::anyhow!("No choices in WarpGrep response"))? + .message; + + let preview_len = assistant_msg.content.len().min(200); + debug!( + "WarpGrep turn {}: {}", + turn + 1, + &assistant_msg.content[..preview_len] + ); + + messages.push(assistant_msg.clone()); + + let tool_calls = parse_tool_calls(&assistant_msg.content); + + // Check for finish + if tool_calls.iter().any(|tc| tc.name == "finish") { + return parse_finish_to_response(repo_dir, &assistant_msg.content); + } + + if tool_calls.is_empty() { + debug!("WarpGrep: no tool calls found in response, ending"); + break; + } + + // Execute tool calls + let mut results = Vec::new(); + for tc in &tool_calls { + if tc.name == "finish" { + continue; + } + let result = match tc.name.as_str() { + "ripgrep" => { + let pattern = tc.params.get("pattern").map(|s| s.as_str()).unwrap_or(""); + let path = tc.params.get("path").map(|s| s.as_str()); + execute_ripgrep(repo_dir, pattern, path) + } + "read" => { + let path = tc.params.get("path").map(|s| s.as_str()).unwrap_or(""); + let lines = tc.params.get("lines").map(|s| s.as_str()); + execute_read(repo_dir, path, lines) + } + "list_directory" => { + let path = tc.params.get("path").map(|s| s.as_str()).unwrap_or(""); + execute_list_directory(repo_dir, path) + } + _ => format!("Unknown tool: {}", tc.name), + }; + results.push((tc.name.clone(), result)); + } + + let tool_response = format_tool_responses(&results); + let turns_used = turn + 1; + let turns_remaining = MAX_TURNS - turns_used; + let user_msg = format!( + "{tool_response}\n[Turns used: {turns_used}/{MAX_TURNS}, Turns remaining: {turns_remaining}]" + ); + + messages.push(Message { + role: "user".to_string(), + content: user_msg, + }); + } + + // If we 
exhausted turns without finish, return empty + warn!("WarpGrep exhausted all turns without finish"); + Ok(CodeSearchResponse { hits: vec![] }) + } +} + +#[async_trait] +impl WarpGrepSearch for WarpGrepService { + async fn search( + &self, + repo_dir: &Path, + query: &str, + ) -> Result { + self.search_impl(repo_dir, query) + .await + .map_err(CodeSearchError::Other) + } +} + +fn build_file_tree(repo_dir: &Path) -> Result { + let mut tree = String::new(); + collect_files(repo_dir, repo_dir, &mut tree, 0, 10000)?; + Ok(tree) +} + +fn collect_files( + base: &Path, + dir: &Path, + output: &mut String, + count: usize, + limit: usize, +) -> Result { + if count >= limit { + output.push_str("... (file list truncated)\n"); + return Ok(count); + } + + let mut entries: Vec<_> = fs::read_dir(dir)? + .filter_map(|e| e.ok()) + .filter(|e| { + let name = e.file_name().to_string_lossy().to_string(); + !name.starts_with('.') && name != "node_modules" && name != "target" && name != "vendor" + }) + .collect(); + entries.sort_by_key(|e| e.file_name()); + + let mut current_count = count; + for entry in entries { + if current_count >= limit { + output.push_str("... 
/// Map a file path to the language identifier attached to a search hit.
///
/// The decision is based solely on the extension of the *final* path
/// component, as reported by [`Path::extension`]. This means:
///
/// * a file with no extension (`Makefile`, or a binary literally named `go`
///   or `sh`) is reported as `"plaintext"` rather than being mistaken for a
///   language whose extension happens to equal the file name;
/// * dots in parent directories (`pkg.v1/README`) are ignored.
///
/// Matching is case-sensitive, except that both `r` and `R` map to `"r"`.
/// Unknown or missing extensions yield `"plaintext"`.
fn detect_language(filepath: &str) -> String {
    // `Path::extension` returns `None` for names without a dot and for
    // dotfiles such as `.gitignore`, which is exactly the "no extension" case.
    let ext = Path::new(filepath)
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("");

    match ext {
        "rs" => "rust",
        "py" => "python",
        "js" => "javascript",
        "ts" => "typescript",
        "tsx" => "typescriptreact",
        "jsx" => "javascriptreact",
        "go" => "go",
        "java" => "java",
        "c" | "h" => "c",
        "cpp" | "cc" | "cxx" | "hpp" => "cpp",
        "rb" => "ruby",
        "php" => "php",
        "swift" => "swift",
        "kt" | "kts" => "kotlin",
        "cs" => "csharp",
        "scala" => "scala",
        "sh" | "bash" => "shellscript",
        "lua" => "lua",
        "r" | "R" => "r",
        "sql" => "sql",
        "toml" => "toml",
        "yaml" | "yml" => "yaml",
        "json" => "json",
        "xml" => "xml",
        "html" => "html",
        "css" => "css",
        "md" => "markdown",
        _ => "plaintext",
    }
    .to_string()
}
+ let canonical = joined.canonicalize().ok()?; + let base = repo_dir.canonicalize().ok()?; + if canonical.starts_with(&base) { + Some(canonical) + } else { + None + } +} + +pub fn execute_ripgrep(repo_dir: &Path, pattern: &str, path: Option<&str>) -> String { + // Validate path if provided + let search_dir = if let Some(p) = path { + match validate_path(repo_dir, p) { + Some(validated) => validated, + None => return format!("Invalid path: {p}"), + } + } else { + repo_dir.to_path_buf() + }; + + let mut cmd = Command::new("grep"); + // Use -F (fixed string) instead of -E (regex) to prevent ReDoS from LLM-supplied patterns + cmd.arg("-rn").arg("-F").arg(pattern); + cmd.arg(&search_dir); + + match cmd.output() { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + if stdout.is_empty() { + return "No matches found.".to_string(); + } + // Strip the repo_dir prefix from output paths + let prefix = repo_dir.to_string_lossy(); + let cleaned: String = stdout + .lines() + .take(100) + .map(|line| { + line.strip_prefix(prefix.as_ref()) + .and_then(|s| s.strip_prefix('/')) + .unwrap_or(line) + }) + .collect::>() + .join("\n"); + truncate_string(cleaned, 4000) + } + Err(e) => { + debug!("grep error: {}", e); + format!("Error: {e}") + } + } +} + +pub fn execute_read(repo_dir: &Path, path: &str, lines: Option<&str>) -> String { + let file_path = match validate_path(repo_dir, path) { + Some(p) => p, + None => return format!("Invalid path: {path}"), + }; + match fs::read_to_string(&file_path) { + Ok(content) => { + if let Some(line_spec) = lines { + extract_line_ranges(&content, line_spec) + } else { + truncate_string(content, 8000) + } + } + Err(_) => format!("File not found: {path}"), + } +} + +pub fn execute_list_directory(repo_dir: &Path, path: &str) -> String { + let dir_path = if path.is_empty() { + repo_dir.to_path_buf() + } else { + match validate_path(repo_dir, path) { + Some(p) => p, + None => return format!("Invalid path: {path}"), + } + 
}; + + match fs::read_dir(&dir_path) { + Ok(entries) => { + let mut items: Vec = entries + .filter_map(|e| e.ok()) + .map(|e| { + let name = e.file_name().to_string_lossy().to_string(); + if e.file_type().map(|t| t.is_dir()).unwrap_or(false) { + format!("{name}/") + } else { + name + } + }) + .collect(); + items.sort(); + truncate_string(items.join("\n"), 4000) + } + Err(_) => format!("Directory not found: {path}"), + } +} + +pub fn parse_tool_calls(content: &str) -> Vec { + let mut calls = Vec::new(); + let tool_names = ["ripgrep", "read", "list_directory", "finish"]; + + for tool_name in &tool_names { + let open_tag = format!("<{tool_name}>"); + let close_tag = format!(""); + + let mut search_from = 0; + while let Some(start) = content[search_from..].find(&open_tag) { + let abs_start = search_from + start + open_tag.len(); + if let Some(end) = content[abs_start..].find(&close_tag) { + let inner = &content[abs_start..abs_start + end]; + let params = parse_xml_params(inner); + calls.push(ToolCall { + name: tool_name.to_string(), + params, + }); + search_from = abs_start + end + close_tag.len(); + } else { + break; + } + } + } + + calls +} + +fn parse_xml_params(inner: &str) -> HashMap { + let mut params = HashMap::new(); + let param_names = ["pattern", "path", "glob", "lines", "result"]; + + for param_name in ¶m_names { + let open = format!("<{param_name}>"); + let close = format!(""); + + if let Some(start) = inner.find(&open) { + let val_start = start + open.len(); + if let Some(end) = inner[val_start..].find(&close) { + let value = inner[val_start..val_start + end].trim().to_string(); + params.insert(param_name.to_string(), value); + } + } + } + + params +} + +pub fn format_tool_responses(results: &[(String, String)]) -> String { + let mut output = String::new(); + for (name, result) in results { + output.push_str(&format!( + "\n<{name}>\n{result}\n\n\n" + )); + } + output +} + +pub fn parse_finish_output(content: &str) -> Vec<(String, Vec<(usize, usize)>)> { + 
/// Render the requested line ranges of `content` as `N:text` lines.
///
/// `line_spec` is a comma-separated list of inclusive, 1-indexed ranges such
/// as `"1-50,80-100"`. Each selected line is emitted as `<line number>:<text>`
/// followed by a newline, in the order the ranges are given. The combined
/// output is capped at 8000 bytes by `truncate_string`.
///
/// Malformed entries (no `-`, non-numeric bounds) are skipped. Ranges are
/// clamped to the file: an end past the last line stops at the last line, and
/// a range that starts past the end of the file or whose start is greater than
/// its end selects nothing instead of panicking. The spec originates from
/// model output, so it cannot be trusted to be well-formed.
fn extract_line_ranges(content: &str, line_spec: &str) -> String {
    let lines: Vec<&str> = content.lines().collect();
    let mut output = String::new();

    for range in line_spec.split(',') {
        let range = range.trim();
        if let Some((start_str, end_str)) = range.split_once('-') {
            if let (Ok(start), Ok(end)) = (
                start_str.trim().parse::<usize>(),
                end_str.trim().parse::<usize>(),
            ) {
                let end = end.min(lines.len());
                // 1-indexed to 0-indexed, then clamp so `start..end` is always
                // a valid (possibly empty) slice range.
                let start = start.saturating_sub(1).min(end);
                for (offset, line) in lines[start..end].iter().enumerate() {
                    output.push_str(&format!("{}:{}\n", start + offset + 1, line));
                }
            }
        }
    }

    truncate_string(output, 8000)
}

/// Cap `s` at roughly `max_len` bytes, appending a truncation marker.
///
/// Strings of at most `max_len` bytes are returned unchanged. Longer strings
/// are cut at the largest UTF-8 character boundary that is `<= max_len`, so a
/// multi-byte character is never split, and `"\n... (truncated)"` is appended.
fn truncate_string(s: String, max_len: usize) -> String {
    if s.len() <= max_len {
        return s;
    }
    // Walk back to a char boundary; index 0 is always one, so this terminates.
    let mut boundary = max_len;
    while !s.is_char_boundary(boundary) {
        boundary -= 1;
    }
    let mut truncated = s;
    truncated.truncate(boundary);
    truncated.push_str("\n... (truncated)");
    truncated
}
extract_line_ranges(content, "2-4"); + assert!(result.contains("2:line2")); + assert!(result.contains("3:line3")); + assert!(result.contains("4:line4")); + assert!(!result.contains("1:line1")); + } +} diff --git a/ee/tabby-webserver/src/service/answer.rs b/ee/tabby-webserver/src/service/answer.rs index ecf1e0a3d36b..63ad83ae6af3 100644 --- a/ee/tabby-webserver/src/service/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -522,6 +522,7 @@ mod tests { Some(code.clone()), Some(doc.clone()), serper, + None, repo, settings, )); @@ -586,6 +587,7 @@ mod tests { Some(code.clone()), Some(doc.clone()), serper, + None, repo, settings, )); @@ -650,6 +652,7 @@ mod tests { Some(code.clone()), Some(doc.clone()), serper, + None, repo, settings, )); diff --git a/ee/tabby-webserver/src/service/page.rs b/ee/tabby-webserver/src/service/page.rs index 8e49da85b674..cf3803fd2f42 100644 --- a/ee/tabby-webserver/src/service/page.rs +++ b/ee/tabby-webserver/src/service/page.rs @@ -1115,6 +1115,7 @@ mod tests { Some(code.clone()), Some(doc.clone()), serper, + None, repo_service.clone(), settings, )); diff --git a/ee/tabby-webserver/src/service/retrieval.rs b/ee/tabby-webserver/src/service/retrieval.rs index 143dbc35b00a..b1cba027ad2d 100644 --- a/ee/tabby-webserver/src/service/retrieval.rs +++ b/ee/tabby-webserver/src/service/retrieval.rs @@ -9,7 +9,7 @@ use std::{ use tabby_common::api::{ code::{ CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchParams, CodeSearchQuery, - CodeSearchScores, + CodeSearchScores, WarpGrepSearch, }, structured_doc::{DocSearch, DocSearchDocument, DocSearchError, DocSearchHit}, }; @@ -33,6 +33,7 @@ pub struct RetrievalService { code: Option>, doc: Option>, serper: Option>, + warpgrep: Option>, repository: Arc, settings: Arc, } @@ -42,6 +43,7 @@ impl RetrievalService { code: Option>, doc: Option>, serper: Option>, + warpgrep: Option>, repository: Arc, settings: Arc, ) -> Self { @@ -49,6 +51,7 @@ impl RetrievalService { code, doc, serper, + warpgrep, 
repository, settings, } @@ -132,33 +135,63 @@ impl RetrievalService { params: &CodeSearchParams, override_params: Option<&CodeSearchParamsOverrideInput>, ) -> Vec { - let Some(code) = self.code.as_ref() else { - return vec![]; - }; - - let query = CodeSearchQuery::new( - input.filepath.clone(), - input.language.clone(), - helper.rewrite_tag(&input.content), - repository.source_id.clone(), - ); + let content = helper.rewrite_tag(&input.content); + let mut all_hits = Vec::new(); + + // 1. Tantivy code search + if let Some(code) = self.code.as_ref() { + let query = CodeSearchQuery::new( + input.filepath.clone(), + input.language.clone(), + content.clone(), + repository.source_id.clone(), + ); + + let mut params = params.clone(); + if let Some(override_params) = override_params { + override_params.override_params(&mut params); + } - let mut params = params.clone(); - if let Some(override_params) = override_params { - override_params.override_params(&mut params); + match code.search_in_language(query, params).await { + Ok(docs) => { + let merged = merge_code_snippets(repository, docs.hits).await; + all_hits.extend(merged); + } + Err(err) => { + if let CodeSearchError::NotReady = err { + debug!("Code search is not ready yet"); + } else { + warn!("Failed to search code: {:?}", err); + } + } + } } - match code.search_in_language(query, params).await { - Ok(docs) => merge_code_snippets(repository, docs.hits).await, - Err(err) => { - if let CodeSearchError::NotReady = err { - debug!("Code search is not ready yet"); - } else { - warn!("Failed to search code: {:?}", err); + // 2. 
WarpGrep supplemental search + if let Some(warpgrep) = self.warpgrep.as_ref() { + match warpgrep.search(&repository.dir, &content).await { + Ok(resp) => { + debug!("WarpGrep returned {} hits", resp.hits.len()); + // Deduplicate: skip WarpGrep hits that overlap with existing + for mut hit in resp.hits { + hit.doc.git_url.clone_from(&repository.git_url); + let dominated = all_hits.iter().any(|existing| { + existing.doc.filepath == hit.doc.filepath + && existing.doc.start_line == hit.doc.start_line + }); + if !dominated { + all_hits.push(hit); + } + } + } + Err(err) => { + warn!("WarpGrep search failed (non-fatal): {:?}", err); } - vec![] } } + + all_hits.sort_by(|a, b| b.scores.rrf.total_cmp(&a.scores.rrf)); + all_hits } pub async fn collect_relevant_docs( @@ -249,10 +282,11 @@ pub fn create( code: Option>, doc: Option>, serper: Option>, + warpgrep: Option>, repository: Arc, settings: Arc, ) -> RetrievalService { - RetrievalService::new(code, doc, serper, repository, settings) + RetrievalService::new(code, doc, serper, warpgrep, repository, settings) } /// Combine code snippets from search results rather than utilizing multiple hits: @@ -590,7 +624,8 @@ mod tests { let repo_service = make_repository_service(db.clone()).await.unwrap(); let settings = Arc::new(setting::create(db)); - let service = RetrievalService::new(Some(code), Some(doc), None, repo_service, settings); + let service = + RetrievalService::new(Some(code), Some(doc), None, None, repo_service, settings); // Test Case 1: Basic code collection let input = make_code_query_input(Some(&test_repo.source_id), Some(&test_repo.git_url)); @@ -646,6 +681,7 @@ mod tests { Some(code.clone()), Some(doc.clone()), serper, + None, repo, settings, ); @@ -815,6 +851,7 @@ mod tests { Some(code.clone()), Some(doc.clone()), serper, + None, repo_service.clone(), settings, )); diff --git a/ee/tabby-webserver/src/service/thread.rs b/ee/tabby-webserver/src/service/thread.rs index acffce93a41f..2ab578960708 100644 --- 
a/ee/tabby-webserver/src/service/thread.rs +++ b/ee/tabby-webserver/src/service/thread.rs @@ -704,6 +704,7 @@ mod tests { Some(code.clone()), Some(doc.clone()), serper, + None, repo, settings, )); diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/webserver.rs index d71cb62b5758..c9bba1cd190d 100644 --- a/ee/tabby-webserver/src/webserver.rs +++ b/ee/tabby-webserver/src/webserver.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use axum::Router; use tabby_common::{ api::{ - code::CodeSearch, + code::{CodeSearch, WarpGrepSearch}, event::{ComposedLogger, EventLogger}, structured_doc::DocSearch, }, @@ -66,6 +66,7 @@ impl Webserver { completion: Option>, docsearch: Option>, serper_factory_fn: impl Fn(&str) -> Box, + warpgrep_factory_fn: impl Fn(&str) -> Box, ) -> (Router, Router) { let serper: Option> = if let Ok(api_key) = std::env::var("SERPER_API_KEY") { @@ -75,6 +76,14 @@ impl Webserver { None }; + let warpgrep: Option> = + if let Ok(api_key) = std::env::var("MORPH_API_KEY") { + debug!("Morph API key found, enabling WarpGrep code search..."); + Some(warpgrep_factory_fn(&api_key)) + } else { + None + }; + let db = self.db.clone(); let job: Arc = Arc::new(job::create(db.clone()).await); @@ -118,6 +127,7 @@ impl Webserver { code.clone(), docsearch.clone(), serper, + warpgrep, repository.clone(), setting.clone(), ));