Skip to content

Commit 7434ae8

Browse files
Zhe YuDavidyz
authored andcommitted
refactor(cli): Stop hardcoding rewriter
1 parent 92dad1d commit 7434ae8

4 files changed

Lines changed: 37 additions & 4 deletions

File tree

src/vectorcode/cli_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class Config:
9797
filetype_map: dict[str, list[str]] = field(default_factory=dict)
9898
encoding: str = "utf8"
9999
hooks: bool = False
100+
use_rewriter: bool = False
100101
rewriter: Optional[str] = None
101102
rewriter_params: dict[str, Any] = field(default_factory=dict)
102103

@@ -317,6 +318,11 @@ def get_cli_parser():
317318
help="What to include in the final output.",
318319
default=__default_config.include,
319320
)
321+
query_parser.add_argument(
322+
"--rewrite",
323+
action="store_true",
324+
help="Apply rewriter to rewrite the query before running the search.",
325+
)
320326

321327
subparsers.add_parser("drop", parents=[shared_parser], help="Remove a collection.")
322328
hooks_parser = subparsers.add_parser(
@@ -422,6 +428,7 @@ async def parse_cli_args(args: Optional[Sequence[str]] = None):
422428
configs_items["use_absolute_path"] = main_args.absolute
423429
configs_items["include"] = [QueryInclude(i) for i in main_args.include]
424430
configs_items["encoding"] = main_args.encoding
431+
configs_items["use_rewriter"] = main_args.rewrite
425432
case "check":
426433
configs_items["check_item"] = main_args.check_item
427434
case "init":
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,28 @@
1+
import logging
2+
import sys
3+
from typing import Optional
4+
5+
from vectorcode.cli_utils import Config
6+
17
from .base import RewriterBase
28
from .openai import OpenAIRewriter
39

10+
logger = logging.getLogger(name=__name__)
411
__all__ = ["RewriterBase", "OpenAIRewriter"]
12+
13+
14+
class RewriterError(Exception):
15+
pass
16+
17+
18+
def get_rewriter(configs: Config) -> Optional[RewriterBase]:
19+
if configs.rewriter is None:
20+
logger.warning("Rewriter hasn't been configured. Skipping rewriting.")
21+
return None
22+
if configs.rewriter == "RewriterBase":
23+
raise RewriterError("RewriterBase is not a valid rewriter!")
24+
rewriter_cls = getattr(sys.modules[__name__], configs.rewriter)
25+
if issubclass(rewriter_cls, RewriterBase):
26+
logger.info(f"Loaded {configs.rewriter}")
27+
return rewriter_cls(configs)
28+
raise RewriterError(f"Failed to find {configs.rewriter}!")

src/vectorcode/rewriter/openai.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ async def rewrite(self, original_query: list[str]):
6464
return original_query
6565
choice = comp.choices[0].message
6666
if choice and choice.parsed:
67-
print(choice.parsed)
6867
logger.debug(f"Rewritten queries to: {choice.parsed}")
6968
return choice.parsed.keywords
7069
else:

src/vectorcode/subcommands/query/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
get_collection,
2121
verify_ef,
2222
)
23-
from vectorcode.rewriter import OpenAIRewriter
2423
from vectorcode.subcommands.query.reranker import (
2524
RerankerError,
2625
get_reranker,
@@ -34,8 +33,12 @@ async def get_query_result_files(
3433
) -> list[str]:
3534
query_chunks = []
3635
if configs.query:
37-
if configs.rewriter:
38-
configs.query = await OpenAIRewriter(configs).rewrite(configs.query)
36+
if configs.use_rewriter:
37+
from vectorcode.rewriter import get_rewriter # lazy import
38+
39+
rewriter = get_rewriter(configs)
40+
if rewriter is not None:
41+
configs.query = await rewriter.rewrite(configs.query)
3942
chunker = StringChunker(configs)
4043
for q in configs.query:
4144
query_chunks.extend(str(i) for i in chunker.chunk(q))

0 commit comments

Comments
 (0)