File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -97,6 +97,7 @@ class Config:
9797 filetype_map : dict [str , list [str ]] = field (default_factory = dict )
9898 encoding : str = "utf8"
9999 hooks : bool = False
100+ use_rewriter : bool = False
100101 rewriter : Optional [str ] = None
101102 rewriter_params : dict [str , Any ] = field (default_factory = dict )
102103
@@ -317,6 +318,11 @@ def get_cli_parser():
317318 help = "What to include in the final output." ,
318319 default = __default_config .include ,
319320 )
321+ query_parser .add_argument (
322+ "--rewrite" ,
323+ action = "store_true" ,
324+ help = "Apply rewriter to rewrite the query before running the search." ,
325+ )
320326
321327 subparsers .add_parser ("drop" , parents = [shared_parser ], help = "Remove a collection." )
322328 hooks_parser = subparsers .add_parser (
@@ -422,6 +428,7 @@ async def parse_cli_args(args: Optional[Sequence[str]] = None):
422428 configs_items ["use_absolute_path" ] = main_args .absolute
423429 configs_items ["include" ] = [QueryInclude (i ) for i in main_args .include ]
424430 configs_items ["encoding" ] = main_args .encoding
431+ configs_items ["use_rewriter" ] = main_args .rewrite
425432 case "check" :
426433 configs_items ["check_item" ] = main_args .check_item
427434 case "init" :
Original file line number Diff line number Diff line change 1+ import logging
2+ import sys
3+ from typing import Optional
4+
5+ from vectorcode .cli_utils import Config
6+
17from .base import RewriterBase
28from .openai import OpenAIRewriter
39
10+ logger = logging .getLogger (name = __name__ )
411__all__ = ["RewriterBase" , "OpenAIRewriter" ]
12+
13+
class RewriterError(Exception):
    """Error raised when a query rewriter cannot be loaded."""
16+
17+
def get_rewriter(configs: Config) -> Optional[RewriterBase]:
    """Instantiate the rewriter class named by ``configs.rewriter``.

    The name is looked up as an attribute of this module, so only rewriter
    classes that are importable from this package can be selected.

    Args:
        configs: Configuration object; ``configs.rewriter`` holds the class
            name, and the whole object is passed to the rewriter constructor.

    Returns:
        A rewriter instance, or ``None`` (after logging a warning) when
        ``configs.rewriter`` is ``None``.

    Raises:
        RewriterError: If the name is ``"RewriterBase"``, does not exist in
            this module, or does not refer to a ``RewriterBase`` subclass.
    """
    if configs.rewriter is None:
        logger.warning("Rewriter hasn't been configured. Skipping rewriting.")
        return None
    if configs.rewriter == "RewriterBase":
        raise RewriterError("RewriterBase is not a valid rewriter!")

    # Use a default so that an unknown name ends up as RewriterError below
    # instead of leaking a bare AttributeError to the caller.
    rewriter_cls = getattr(sys.modules[__name__], configs.rewriter, None)

    # issubclass() raises TypeError for non-class objects (e.g. the module's
    # ``logger`` or ``sys``), so confirm we actually resolved a class first.
    if isinstance(rewriter_cls, type) and issubclass(rewriter_cls, RewriterBase):
        logger.info(f"Loaded {configs.rewriter}")
        return rewriter_cls(configs)
    raise RewriterError(f"Failed to find {configs.rewriter}!")
Original file line number Diff line number Diff line change @@ -64,7 +64,6 @@ async def rewrite(self, original_query: list[str]):
6464 return original_query
6565 choice = comp .choices [0 ].message
6666 if choice and choice .parsed :
67- print (choice .parsed )
6867 logger .debug (f"Rewritten queries to: { choice .parsed } " )
6968 return choice .parsed .keywords
7069 else :
Original file line number Diff line number Diff line change 2020 get_collection ,
2121 verify_ef ,
2222)
23- from vectorcode .rewriter import OpenAIRewriter
2423from vectorcode .subcommands .query .reranker import (
2524 RerankerError ,
2625 get_reranker ,
@@ -34,8 +33,12 @@ async def get_query_result_files(
3433) -> list [str ]:
3534 query_chunks = []
3635 if configs .query :
37- if configs .rewriter :
38- configs .query = await OpenAIRewriter (configs ).rewrite (configs .query )
36+ if configs .use_rewriter :
37+ from vectorcode .rewriter import get_rewriter # lazy import
38+
39+ rewriter = get_rewriter (configs )
40+ if rewriter is not None :
41+ configs .query = await rewriter .rewrite (configs .query )
3942 chunker = StringChunker (configs )
4043 for q in configs .query :
4144 query_chunks .extend (str (i ) for i in chunker .chunk (q ))
You can’t perform that action at this time.
0 commit comments