Skip to content

Commit f725c7f

Browse files
Zhe YuDavidyz
authored andcommitted
test(cli): tests for rewriter.
1 parent 7434ae8 commit f725c7f

4 files changed

Lines changed: 38 additions & 8 deletions

File tree

src/vectorcode/rewriter/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
__all__ = ["RewriterBase", "OpenAIRewriter"]
1212

1313

14-
class RewriterError(Exception):
14+
class RewriterError(Exception): # pragma: nocover
1515
pass
1616

1717

@@ -21,8 +21,9 @@ def get_rewriter(configs: Config) -> Optional[RewriterBase]:
2121
return None
2222
if configs.rewriter == "RewriterBase":
2323
raise RewriterError("RewriterBase is not a valid rewriter!")
24-
rewriter_cls = getattr(sys.modules[__name__], configs.rewriter)
25-
if issubclass(rewriter_cls, RewriterBase):
26-
logger.info(f"Loaded {configs.rewriter}")
27-
return rewriter_cls(configs)
24+
if hasattr(sys.modules[__name__], configs.rewriter):
25+
rewriter_cls = getattr(sys.modules[__name__], configs.rewriter)
26+
if issubclass(rewriter_cls, RewriterBase):
27+
logger.info(f"Loaded {configs.rewriter}")
28+
return rewriter_cls(configs)
2829
raise RewriterError(f"Failed to find {configs.rewriter}!")

src/vectorcode/rewriter/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from vectorcode.cli_utils import Config
44

55

6-
class RewriterBase(ABC):
6+
class RewriterBase(ABC): # pragma: nocover
77
def __init__(self, config: Config) -> None:
88
super().__init__()
99
self.config = config

src/vectorcode/rewriter/openai.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
from typing import override
32

43
import openai
54
from openai.types.chat import ChatCompletion
@@ -49,7 +48,6 @@ def __init__(self, config: Config) -> None:
4948
No hallucinations.
5049
"""
5150

52-
@override
5351
async def rewrite(self, original_query: list[str]):
5452
comp: ChatCompletion = self.client.beta.chat.completions.parse(
5553
messages=[

tests/rewriter/test_rewriter.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pytest
4+
5+
from vectorcode.cli_utils import Config
6+
from vectorcode.rewriter import RewriterError, get_rewriter
7+
8+
9+
def test_get_rewriter():
10+
assert get_rewriter(Config()) is None
11+
12+
13+
def test_get_rewriter_base():
14+
with pytest.raises(RewriterError):
15+
get_rewriter(Config(rewriter="RewriterBase"))
16+
17+
18+
def test_get_openai_rewriter():
19+
with (
20+
patch("vectorcode.rewriter.OpenAIRewriter") as mock_openai_cls,
21+
patch("vectorcode.rewriter.issubclass") as mock_issubclass,
22+
):
23+
mock_rewriter = MagicMock()
24+
mock_openai_cls.return_value = mock_rewriter
25+
mock_issubclass.return_value = True
26+
assert get_rewriter(Config(rewriter="OpenAIRewriter")) == mock_rewriter
27+
28+
29+
def test_get_faulty_rewriter():
30+
with pytest.raises(RewriterError):
31+
get_rewriter(Config(rewriter="DummyRewriter"))

0 commit comments

Comments
 (0)