Skip to content

Commit ad1310a

Browse files
committed
refactor: extract symbol retrieval policy
1 parent 4ab276a commit ad1310a

4 files changed

Lines changed: 244 additions & 195 deletions

File tree

src/core/context.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::path::Path;
66
use std::path::PathBuf;
77

88
use crate::core::function_chunker::find_enclosing_boundary_line;
9-
use crate::core::{ContextProvenance, SymbolIndex};
9+
use crate::core::{ContextProvenance, SymbolContextRetriever, SymbolIndex, SymbolRetrievalPolicy};
1010

1111
#[derive(Debug, Clone, Serialize, Deserialize)]
1212
pub struct LLMContextChunk {
@@ -271,13 +271,13 @@ impl ContextFetcher {
271271
}
272272
}
273273

274-
for location in index.graph_related_locations(
275-
file_path,
276-
symbols,
277-
graph_hops,
278-
max_locations,
279-
graph_max_files,
280-
) {
274+
let retriever = SymbolContextRetriever::new(
275+
index,
276+
SymbolRetrievalPolicy::new(max_locations, graph_hops, graph_max_files),
277+
);
278+
let related_locations = retriever.related_symbol_locations(file_path, symbols);
279+
280+
for location in related_locations.definition_locations {
281281
if &location.file_path == file_path {
282282
continue;
283283
}
@@ -290,13 +290,7 @@ impl ContextFetcher {
290290
chunks.push(chunk);
291291
}
292292

293-
for location in index.multi_hop_locations(
294-
file_path,
295-
symbols,
296-
max_locations,
297-
graph_hops,
298-
graph_max_files,
299-
) {
293+
for location in related_locations.reference_locations {
300294
if &location.file_path == file_path {
301295
continue;
302296
}

src/core/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ pub use semantic::{
4545
load_semantic_feedback_store, refresh_semantic_index, save_semantic_feedback_store,
4646
semantic_context_for_diff, SemanticFeedbackExample, SemanticFeedbackStore,
4747
};
48-
pub use symbol_index::SymbolIndex;
48+
pub use symbol_index::{SymbolContextRetriever, SymbolIndex, SymbolRetrievalPolicy};

src/core/symbol_index.rs

Lines changed: 6 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,14 @@ use std::io::{BufRead, BufReader, Read, Write};
1010
use std::path::{Path, PathBuf};
1111
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
1212

13-
use crate::core::symbol_graph::{RankedSymbol, SymbolGraph};
13+
use crate::core::symbol_graph::SymbolGraph;
1414
use crate::core::ContextProvenance;
1515

16+
#[path = "symbol_index/retrieval.rs"]
17+
mod retrieval;
18+
19+
pub use retrieval::{SymbolContextRetriever, SymbolRetrievalPolicy};
20+
1621
#[derive(Debug, Clone)]
1722
pub struct SymbolLocation {
1823
pub file_path: PathBuf,
@@ -387,155 +392,6 @@ impl SymbolIndex {
387392
self.symbols.get(symbol)
388393
}
389394

390-
pub fn graph_related_locations(
391-
&self,
392-
current_file: &Path,
393-
symbols: &[String],
394-
max_hops: usize,
395-
max_locations: usize,
396-
max_files: usize,
397-
) -> Vec<SymbolLocation> {
398-
let Some(graph) = &self.symbol_graph else {
399-
return Vec::new();
400-
};
401-
if symbols.is_empty() || max_locations == 0 || max_files == 0 || max_hops == 0 {
402-
return Vec::new();
403-
}
404-
405-
let ranked = graph.related_symbols(
406-
symbols,
407-
max_hops,
408-
max_locations.saturating_mul(max_files).max(max_locations),
409-
);
410-
411-
let mut results = Vec::new();
412-
let mut seen_locations = HashSet::new();
413-
let mut seen_files = HashSet::new();
414-
415-
for ranked_symbol in ranked {
416-
if ranked_symbol.file_path == current_file {
417-
continue;
418-
}
419-
if seen_files.len() >= max_files && !seen_files.contains(&ranked_symbol.file_path) {
420-
continue;
421-
}
422-
423-
let Some(mut location) = self.lookup_ranked_symbol_location(&ranked_symbol) else {
424-
continue;
425-
};
426-
427-
let location_key = format!(
428-
"{}:{}:{}",
429-
location.file_path.display(),
430-
location.line_range.0,
431-
location.line_range.1
432-
);
433-
if !seen_locations.insert(location_key) {
434-
continue;
435-
}
436-
437-
let relation_path = ranked_symbol
438-
.relation_path
439-
.iter()
440-
.map(|relation| relation.as_label().to_string())
441-
.collect::<Vec<_>>();
442-
location.provenance = Some(ContextProvenance::symbol_graph_path(
443-
relation_path.clone(),
444-
ranked_symbol.hops,
445-
ranked_symbol.relevance_score,
446-
));
447-
location.snippet = format!(
448-
"[Graph: {}, hops={}, relevance={:.2}]\n{}",
449-
relation_path.join(" -> "),
450-
ranked_symbol.hops,
451-
ranked_symbol.relevance_score,
452-
location.snippet
453-
);
454-
seen_files.insert(location.file_path.clone());
455-
results.push(location);
456-
}
457-
458-
results
459-
}
460-
461-
pub fn multi_hop_locations(
462-
&self,
463-
current_file: &Path,
464-
symbols: &[String],
465-
max_locations: usize,
466-
max_hops: usize,
467-
max_files: usize,
468-
) -> Vec<SymbolLocation> {
469-
if symbols.is_empty() || max_files == 0 {
470-
return Vec::new();
471-
}
472-
473-
let mut direct_files = HashSet::new();
474-
let mut locations = Vec::new();
475-
let mut seen_locations = HashSet::new();
476-
477-
for symbol in symbols {
478-
if let Some(entries) = self.lookup(symbol) {
479-
for location in entries.iter().take(max_locations) {
480-
let location_key = format!(
481-
"{}:{}:{}",
482-
location.file_path.display(),
483-
location.line_range.0,
484-
location.line_range.1
485-
);
486-
if seen_locations.insert(location_key) {
487-
direct_files.insert(location.file_path.clone());
488-
locations.push(location.clone());
489-
}
490-
}
491-
}
492-
}
493-
494-
let mut queue: std::collections::VecDeque<(PathBuf, usize)> =
495-
std::collections::VecDeque::new();
496-
let mut seen_files = HashSet::new();
497-
498-
for file in direct_files {
499-
if file == current_file {
500-
continue;
501-
}
502-
seen_files.insert(file.clone());
503-
queue.push_back((file, 0));
504-
}
505-
506-
while let Some((file, depth)) = queue.pop_front() {
507-
if depth >= max_hops {
508-
continue;
509-
}
510-
511-
for neighbor in self.neighbor_files(&file) {
512-
if neighbor == current_file {
513-
continue;
514-
}
515-
if !seen_files.insert(neighbor.clone()) {
516-
continue;
517-
}
518-
queue.push_back((neighbor, depth + 1));
519-
}
520-
}
521-
522-
for file in seen_files.into_iter().take(max_files) {
523-
if locations.iter().any(|location| location.file_path == file) {
524-
continue;
525-
}
526-
if let Some(summary) = self.file_summaries.get(&file) {
527-
locations.push(SymbolLocation {
528-
file_path: file,
529-
line_range: (1, summary.line_count.max(1)),
530-
snippet: format!("[Dependency graph context]\n{}", summary.snippet),
531-
provenance: Some(ContextProvenance::DependencyGraphNeighborhood),
532-
});
533-
}
534-
}
535-
536-
locations
537-
}
538-
539395
pub fn files_indexed(&self) -> usize {
540396
self.files_indexed
541397
}
@@ -589,17 +445,6 @@ impl SymbolIndex {
589445
self.file_summaries.get(file).map(|s| s.snippet.as_str())
590446
}
591447

592-
fn neighbor_files(&self, file: &Path) -> HashSet<PathBuf> {
593-
let mut neighbors = HashSet::new();
594-
if let Some(deps) = self.dependency_graph.get(file) {
595-
neighbors.extend(deps.iter().cloned());
596-
}
597-
if let Some(reverse) = self.reverse_dependency_graph.get(file) {
598-
neighbors.extend(reverse.iter().cloned());
599-
}
600-
neighbors
601-
}
602-
603448
fn build_graph_from_sources(&mut self, sources: &HashMap<PathBuf, String>) {
604449
if sources.is_empty() {
605450
self.symbol_graph = None;
@@ -613,24 +458,6 @@ impl SymbolIndex {
613458
self.symbol_graph = Some(graph);
614459
}
615460
}
616-
617-
fn lookup_ranked_symbol_location(&self, ranked: &RankedSymbol) -> Option<SymbolLocation> {
618-
self.lookup(&ranked.name).and_then(|locations| {
619-
locations
620-
.iter()
621-
.filter(|location| location.file_path == ranked.file_path)
622-
.min_by_key(|location| {
623-
let start = location.line_range.0;
624-
let end = location.line_range.1;
625-
if ranked.line < start {
626-
start - ranked.line
627-
} else {
628-
ranked.line.saturating_sub(end)
629-
}
630-
})
631-
.cloned()
632-
})
633-
}
634461
}
635462

636463
fn normalized_extension_set(lsp_languages: &HashMap<String, String>) -> HashSet<String> {

0 commit comments

Comments
 (0)