Skip to content

Commit 82a34ac

Browse files
committed
Add comprehensive tests for PythonParser functionality
1 parent fadae0d commit 82a34ac

File tree

1 file changed

+226
-0
lines changed

1 file changed

+226
-0
lines changed
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
from codetide.core.models import ImportStatement
2+
from codetide.parsers.python_parser import PythonParser
3+
4+
from tree_sitter import Parser
5+
from pathlib import Path
6+
import pytest
7+
import os
8+
9+
10+
@pytest.fixture
11+
def parser() -> PythonParser:
12+
"""Provides a default instance of the PythonParser."""
13+
return PythonParser()
14+
15+
class TestPythonParser:
16+
17+
def test_initialization(self, parser: PythonParser):
18+
"""Tests the basic properties and initialization of the parser."""
19+
assert parser.language == "python"
20+
assert parser.extension == ".py"
21+
assert parser.tree_parser is not None
22+
assert isinstance(parser.tree_parser, Parser)
23+
24+
@pytest.mark.parametrize("path, expected", [
25+
("my/app/main.py", "my/app/main.py"),
26+
("my/app/__init__.py", "my/app"),
27+
("my\\app\\__init__.py", "my\\app"),
28+
("lib.py", "lib.py"),
29+
])
30+
def test_skip_init_paths(self, path, expected):
31+
"""Tests the removal of __init__.py from paths."""
32+
assert PythonParser._skip_init_paths(Path(path)) == str(Path(expected))
33+
34+
@pytest.mark.parametrize("code, substring, count", [
35+
("import os; os.getcwd()", "os", 2),
36+
("var = my_var", "var", 1),
37+
("variable = my_var", "var", 0),
38+
("def func():\n pass\nfunc()", "func", 2),
39+
("test(test)", "test", 2),
40+
("class MyTest: pass", "MyTest", 1),
41+
("a.b.c(b)", "b", 1),
42+
])
43+
def test_count_occurences_in_code(self, code, substring, count):
44+
"""Tests the regex-based word occurrence counter."""
45+
assert PythonParser.count_occurences_in_code(code, substring) == count
46+
47+
def test_get_content_indentation(self, parser: PythonParser):
48+
"""Tests the _get_content method for preserving indentation."""
49+
code = b"class MyClass:\n def method(self):\n pass"
50+
tree = parser.tree_parser.parse(code)
51+
# function_definition node
52+
method_node = tree.root_node.children[0].children[-1].children[0]
53+
54+
content_no_indent = parser._get_content(code, method_node, preserve_indentation=False)
55+
assert content_no_indent == "def method(self):\n pass"
56+
57+
content_with_indent = parser._get_content(code, method_node, preserve_indentation=True)
58+
assert content_with_indent == " def method(self):\n pass"
59+
60+
@pytest.mark.asyncio
61+
async def test_parse_file(self, parser: PythonParser, tmp_path: Path):
62+
"""Tests parsing a file from disk."""
63+
file_path = tmp_path / "test_module.py"
64+
code_content = "import os\n\nx = 10"
65+
file_path.write_text(code_content, encoding="utf-8")
66+
67+
code_file_model = await parser.parse_file(file_path)
68+
69+
assert code_file_model.file_path == str(file_path.absolute())
70+
assert len(code_file_model.imports) == 1
71+
assert code_file_model.imports[0].source == "os"
72+
assert len(code_file_model.variables) == 1
73+
assert code_file_model.variables[0].name == "x"
74+
assert code_file_model.variables[0].value == "10"
75+
76+
@pytest.mark.asyncio
77+
async def test_parse_file_with_root_path(self, parser: PythonParser, tmp_path: Path):
78+
"""Tests parsing a file with a root path to get a relative file path."""
79+
root_dir = tmp_path / "project"
80+
root_dir.mkdir()
81+
module_path = root_dir / "module"
82+
module_path.mkdir()
83+
file_path = module_path / "test.py"
84+
file_path.write_text("x = 1", encoding="utf-8")
85+
86+
code_file_model = await parser.parse_file(file_path, root_path=root_dir)
87+
88+
# Should be relative to root_dir
89+
expected_relative_path = os.path.join("module", "test.py")
90+
assert code_file_model.file_path == expected_relative_path
91+
92+
class TestPythonParserDetailed:
93+
94+
@pytest.mark.parametrize("code, expected_imports", [
95+
("import os", [ImportStatement(source='os')]),
96+
("import numpy as np", [ImportStatement(name='numpy', alias='np')]),
97+
("from pathlib import Path", [ImportStatement(source='pathlib', name='Path')]),
98+
("from collections import deque, defaultdict", [
99+
ImportStatement(source='collections', name='deque'),
100+
ImportStatement(source='collections', name='defaultdict')
101+
]),
102+
("from typing import List as L", [ImportStatement(source='typing', name='List', alias='L')]),
103+
])
104+
def test_parse_imports(self, parser: PythonParser, code, expected_imports):
105+
"""Tests various import statement formats."""
106+
file_path = Path("test.py")
107+
code_file = parser.parse_code(code.encode('utf-8'), file_path)
108+
print(f"{code_file.imports=}")
109+
assert len(code_file.imports) == len(expected_imports)
110+
for parsed, expected in zip(code_file.imports, expected_imports):
111+
assert parsed.source == expected.source
112+
assert parsed.name == expected.name
113+
assert parsed.alias == expected.alias
114+
115+
def test_parse_function(self, parser: PythonParser):
116+
"""Tests parsing of a complex function definition."""
117+
code = """
118+
@decorator1
119+
@decorator2
120+
async def my_func(a: int, b: str = "default") -> List[str]:
121+
'''docstring'''
122+
return [b] * a
123+
"""
124+
file_path = Path("test.py")
125+
code_file = parser.parse_code(code.encode('utf-8'), file_path)
126+
127+
assert len(code_file.functions) == 1
128+
func = code_file.functions[0]
129+
130+
assert func.name == "my_func"
131+
assert func.decorators == ["@decorator1", "@decorator2"]
132+
assert func.modifiers == ["async"]
133+
134+
sig = func.signature
135+
assert sig is not None
136+
assert sig.return_type == "List[str]"
137+
assert len(sig.parameters) == 2
138+
139+
param1 = sig.parameters[0]
140+
assert param1.name == "a"
141+
assert param1.type_hint == "int"
142+
assert param1.default_value is None
143+
144+
param2 = sig.parameters[1]
145+
assert param2.name == "b"
146+
assert param2.type_hint == "str"
147+
assert param2.default_value == '"default"'
148+
149+
def test_parse_class(self, parser: PythonParser):
150+
"""Tests parsing of a complex class definition."""
151+
code = """
152+
class Child(Base1, Base2):
153+
class_attr: int = 10
154+
155+
def __init__(self, name: str):
156+
self.name = name
157+
158+
@property
159+
def name_upper(self) -> str:
160+
return self.name.upper()
161+
"""
162+
file_path = Path("test.py")
163+
code_file = parser.parse_code(code.encode('utf-8'), file_path)
164+
assert len(code_file.classes) == 1
165+
cls = code_file.classes[0]
166+
167+
assert cls.name == "Child"
168+
assert "Base1" in cls.bases
169+
assert "Base2" in cls.bases
170+
171+
assert len(cls.attributes) == 1
172+
attr = cls.attributes[0]
173+
assert attr.name == "class_attr"
174+
assert attr.type_hint == "int"
175+
assert attr.value == "10"
176+
177+
assert len(cls.methods) == 2
178+
method1 = next(m for m in cls.methods if m.name == "__init__")
179+
method2 = next(m for m in cls.methods if m.name == "name_upper")
180+
181+
assert method1.name == "__init__"
182+
assert len(method1.signature.parameters) == 1 # name
183+
assert method1.decorators == []
184+
185+
assert method2.name == "name_upper"
186+
assert method2.signature.return_type == "str"
187+
assert method2.decorators == ["@property"]
188+
189+
def test_intra_file_dependencies(self, parser: PythonParser):
190+
"""Tests resolving references within a single file."""
191+
code = """
192+
from typing import List
193+
194+
class Helper:
195+
def do_work(self):
196+
return "done"
197+
198+
def process_data(items: List[str]) -> Helper:
199+
h = Helper()
200+
h.do_work()
201+
return h
202+
203+
var = process_data([])
204+
"""
205+
file_path = Path("test.py")
206+
code_file = parser.parse_code(code.encode('utf-8'), file_path)
207+
parser.resolve_intra_file_dependencies([code_file])
208+
209+
# process_data should reference List and Helper
210+
process_func = code_file.get("test.process_data")
211+
assert len(process_func.references) == 3
212+
ref_names = {ref.name for ref in process_func.references}
213+
assert "List" in ref_names
214+
assert "do_work" in ref_names
215+
216+
# Class Helper method `do_work` is referenced
217+
do_work_method = code_file.get("test.Helper.do_work")
218+
# Assert that `process_data` references `do_work`
219+
found = any(ref.unique_id == do_work_method.unique_id for ref in process_func.references)
220+
assert found
221+
assert "h.do_work" in process_func.raw # Simple check
222+
223+
# var should reference process_data
224+
var_decl = code_file.get("test.var")
225+
assert len(var_decl.references) == 1
226+
assert var_decl.references[0].unique_id == process_func.unique_id

0 commit comments

Comments
 (0)