Skip to content

Commit 3866306

Browse files
feat(controller): use better abstract methods
1 parent af840c3 commit 3866306

2 files changed

Lines changed: 39 additions & 12 deletions

File tree

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import List
1+
from typing import List, Dict
22
from modular_search.controller.core import SearchController
33
from modular_search.blocks.codebase import CodebaseSearchBlock, CodebaseSearchResult
44

5+
56
class CodebaseSearchController(SearchController[CodebaseSearchResult]):
67
"""
78
Codebase Search Controller Class
@@ -12,7 +13,15 @@ def __init__(self, search_block: CodebaseSearchBlock):
1213
"CodebaseSearchBlock": search_block
1314
})
1415

15-
def search(self, query: str) -> List[CodebaseSearchResult]:
16-
# since it's just one block, it makes our life easier
17-
all_results = super().internal_search(query, ["CodebaseSearchBlock"])
18-
return all_results["CodebaseSearchBlock"]
16+
def select_blocks(self, query: str) -> List[str]:
17+
"""
18+
Selects the unit search blocks to be used for the given query.
19+
For CodebaseSearchController, we always use the CodebaseSearchBlock.
20+
"""
21+
return ["CodebaseSearchBlock"]
22+
23+
def aggregate(self, search_results: Dict[str, List[CodebaseSearchResult]]) -> List[CodebaseSearchResult]:
24+
""" Aggregates the search results from the CodebaseSearchBlock.
25+
Since we only have one block, we can return the results directly.
26+
"""
27+
return search_results["CodebaseSearchBlock"]

src/modular_search/controller/core.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,17 @@ def __init__(self, unit_blocks: Union[Dict[str, UnitSearchBlock[O]], List[UnitSe
7676
unit_blocks = block_dict
7777

7878
self.unit_blocks = unit_blocks
79-
79+
8080
@abstractmethod
81-
def search(self, query: str) -> List[O]:
81+
def select_blocks(self, query: str) -> List[str]:
82+
"""
83+
Selects the unit search blocks to be used for the given query.
84+
This method should be implemented by subclasses to define
85+
how blocks are selected based on the query.
86+
"""
8287
pass
8388

84-
def internal_search(self, query: str, active_blocks: List[str]) -> Dict[str, List[O]]:
85-
missing_blocks = set(active_blocks) - set(self.unit_blocks.keys())
86-
if missing_blocks:
87-
raise ValueError(f"Active blocks {missing_blocks} are not registered in the controller.")
88-
89+
def block_search(self, query: str, active_blocks: List[str]) -> Dict[str, List[O]]:
8990
# Dispatch to active Unit Search Blocks
9091
all_results = {}
9192
for block_name in active_blocks:
@@ -94,4 +95,21 @@ def internal_search(self, query: str, active_blocks: List[str]) -> Dict[str, Lis
9495
all_results[block_name] = results
9596

9697
return all_results
98+
99+
@abstractmethod
100+
def aggregate(self, search_results: Dict[str, List[O]]) -> List[O]:
101+
pass
102+
103+
def search(self, query: str) -> List[O]:
104+
active_blocks = self.select_blocks(query)
105+
106+
missing_blocks = set(active_blocks) - set(self.unit_blocks.keys())
107+
if missing_blocks:
108+
raise ValueError(f"Active blocks {missing_blocks} are not registered in the controller.")
109+
110+
search_results = self.block_search(query, active_blocks)
111+
112+
aggregated_results = self.aggregate(search_results)
113+
114+
return aggregated_results
97115

0 commit comments

Comments
 (0)