File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1111__all__ = ["RewriterBase" , "OpenAIRewriter" ]
1212
1313
14- class RewriterError (Exception ):
14+ class RewriterError (Exception ): # pragma: nocover
1515 pass
1616
1717
@@ -21,8 +21,9 @@ def get_rewriter(configs: Config) -> Optional[RewriterBase]:
2121 return None
2222 if configs .rewriter == "RewriterBase" :
2323 raise RewriterError ("RewriterBase is not a valid rewriter!" )
24- rewriter_cls = getattr (sys .modules [__name__ ], configs .rewriter )
25- if issubclass (rewriter_cls , RewriterBase ):
26- logger .info (f"Loaded { configs .rewriter } " )
27- return rewriter_cls (configs )
24+ if hasattr (sys .modules [__name__ ], configs .rewriter ):
25+ rewriter_cls = getattr (sys .modules [__name__ ], configs .rewriter )
26+ if issubclass (rewriter_cls , RewriterBase ):
27+ logger .info (f"Loaded { configs .rewriter } " )
28+ return rewriter_cls (configs )
2829 raise RewriterError (f"Failed to find { configs .rewriter } !" )
Original file line number Diff line number Diff line change 33from vectorcode .cli_utils import Config
44
55
6- class RewriterBase (ABC ):
6+ class RewriterBase (ABC ): # pragma: nocover
77 def __init__ (self , config : Config ) -> None :
88 super ().__init__ ()
99 self .config = config
Original file line number Diff line number Diff line change 11import logging
2- from typing import override
32
43import openai
54from openai .types .chat import ChatCompletion
@@ -49,7 +48,6 @@ def __init__(self, config: Config) -> None:
4948 No hallucinations.
5049"""
5150
52- @override
5351 async def rewrite (self , original_query : list [str ]):
5452 comp : ChatCompletion = self .client .beta .chat .completions .parse (
5553 messages = [
Original file line number Diff line number Diff line change 1+ from unittest .mock import MagicMock , patch
2+
3+ import pytest
4+
5+ from vectorcode .cli_utils import Config
6+ from vectorcode .rewriter import RewriterError , get_rewriter
7+
8+
9+ def test_get_rewriter ():
10+ assert get_rewriter (Config ()) is None
11+
12+
13+ def test_get_rewriter_base ():
14+ with pytest .raises (RewriterError ):
15+ get_rewriter (Config (rewriter = "RewriterBase" ))
16+
17+
18+ def test_get_openai_rewriter ():
19+ with (
20+ patch ("vectorcode.rewriter.OpenAIRewriter" ) as mock_openai_cls ,
21+ patch ("vectorcode.rewriter.issubclass" ) as mock_issubclass ,
22+ ):
23+ mock_rewriter = MagicMock ()
24+ mock_openai_cls .return_value = mock_rewriter
25+ mock_issubclass .return_value = True
26+ assert get_rewriter (Config (rewriter = "OpenAIRewriter" )) == mock_rewriter
27+
28+
29+ def test_get_faulty_rewriter ():
30+ with pytest .raises (RewriterError ):
31+ get_rewriter (Config (rewriter = "DummyRewriter" ))
You can’t perform that action at this time.
0 commit comments