Skip to content

Commit 20d093b

Browse files
gkorlandCopilot
andcommitted
fix: address review feedback for Kotlin analyzer
- Replace wildcard import with explicit Entity and File imports - Fix tree-sitter queries: Kotlin grammar uses 'identifier' not 'type_identifier' - Fix get_entity_name: use 'identifier' for all entity types - Separate superclass/interface in add_symbols: first delegation specifier is base_class, rest are implement_interface - Use self._captures() instead of direct query.captures() calls - Handle constructor_invocation in delegation specifiers (e.g. Shape(...)) - Fix source_analyzer second_pass: use entity.resolved_symbols instead of iterating raw symbol nodes, so graph edges use resolved entity IDs - Fix resolve_method: use 'identifier' instead of 'simple_identifier' - Add unit tests and Kotlin fixture (11 tests, all passing) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 6442af8 commit 20d093b

File tree

5 files changed

+222
-49
lines changed

5 files changed

+222
-49
lines changed

api/analyzers/kotlin/analyzer.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
2-
from ...entities import *
2+
from ...entities.entity import Entity
3+
from ...entities.file import File
34
from typing import Optional
45
from ..analyzer import AbstractAnalyzer
56

@@ -38,15 +39,9 @@ def get_entity_label(self, node: Node) -> str:
3839
raise ValueError(f"Unknown entity type: {node.type}")
3940

4041
def get_entity_name(self, node: Node) -> str:
41-
if node.type in ['class_declaration', 'object_declaration']:
42-
# Find the type_identifier child
43-
for child in node.children:
44-
if child.type == 'type_identifier':
45-
return child.text.decode('utf-8')
46-
elif node.type == 'function_declaration':
47-
# Find the simple_identifier child
42+
if node.type in ['class_declaration', 'object_declaration', 'function_declaration']:
4843
for child in node.children:
49-
if child.type == 'simple_identifier':
44+
if child.type == 'identifier':
5045
return child.text.decode('utf-8')
5146
raise ValueError(f"Cannot extract name from entity type: {node.type}")
5247

@@ -64,52 +59,58 @@ def get_entity_docstring(self, node: Node) -> Optional[str]:
6459
def get_entity_types(self) -> list[str]:
6560
return ['class_declaration', 'object_declaration', 'function_declaration']
6661

62+
def _get_delegation_types(self, entity: Entity) -> list:
63+
"""Extract type identifiers from delegation specifiers in order."""
64+
types = []
65+
for child in entity.node.children:
66+
if child.type == 'delegation_specifiers':
67+
for spec in child.children:
68+
if spec.type == 'delegation_specifier':
69+
for sub in spec.children:
70+
if sub.type == 'constructor_invocation':
71+
for s in sub.children:
72+
if s.type == 'user_type':
73+
for id_node in s.children:
74+
if id_node.type == 'identifier':
75+
types.append(id_node)
76+
elif sub.type == 'user_type':
77+
for id_node in sub.children:
78+
if id_node.type == 'identifier':
79+
types.append(id_node)
80+
return types
81+
6782
def add_symbols(self, entity: Entity) -> None:
6883
if entity.node.type == 'class_declaration':
69-
# Find superclass (extends)
70-
superclass_query = self.language.query("(delegation_specifier (user_type (type_identifier) @superclass))")
71-
superclass_captures = superclass_query.captures(entity.node)
72-
if 'superclass' in superclass_captures:
73-
for superclass in superclass_captures['superclass']:
74-
entity.add_symbol("base_class", superclass)
75-
76-
# Find interfaces (implements)
77-
# In Kotlin, both inheritance and interface implementation use the same syntax
78-
# We'll treat all as interfaces for now since Kotlin can only extend one class
79-
interface_query = self.language.query("(delegation_specifier (user_type (type_identifier) @interface))")
80-
interface_captures = interface_query.captures(entity.node)
81-
if 'interface' in interface_captures:
82-
for interface in interface_captures['interface']:
83-
entity.add_symbol("implement_interface", interface)
84+
types = self._get_delegation_types(entity)
85+
if types:
86+
# First one is the superclass (base_class)
87+
entity.add_symbol("base_class", types[0])
88+
# Remaining are interfaces
89+
for iface in types[1:]:
90+
entity.add_symbol("implement_interface", iface)
8491

8592
elif entity.node.type == 'object_declaration':
86-
# Objects can also have delegation specifiers
87-
interface_query = self.language.query("(delegation_specifier (user_type (type_identifier) @interface))")
88-
interface_captures = interface_query.captures(entity.node)
89-
if 'interface' in interface_captures:
90-
for interface in interface_captures['interface']:
91-
entity.add_symbol("implement_interface", interface)
93+
types = self._get_delegation_types(entity)
94+
for t in types:
95+
entity.add_symbol("implement_interface", t)
9296

9397
elif entity.node.type == 'function_declaration':
9498
# Find function calls
95-
query = self.language.query("(call_expression) @reference.call")
96-
captures = query.captures(entity.node)
99+
captures = self._captures("(call_expression) @reference.call", entity.node)
97100
if 'reference.call' in captures:
98101
for caller in captures['reference.call']:
99102
entity.add_symbol("call", caller)
100103

101104
# Find parameters with types
102-
param_query = self.language.query("(parameter type: (user_type (type_identifier) @parameter))")
103-
param_captures = param_query.captures(entity.node)
104-
if 'parameter' in param_captures:
105-
for parameter in param_captures['parameter']:
105+
captures = self._captures("(parameter (user_type (identifier) @parameter))", entity.node)
106+
if 'parameter' in captures:
107+
for parameter in captures['parameter']:
106108
entity.add_symbol("parameters", parameter)
107109

108110
# Find return type
109-
return_type_query = self.language.query("(function_declaration type: (user_type (type_identifier) @return_type))")
110-
return_type_captures = return_type_query.captures(entity.node)
111-
if 'return_type' in return_type_captures:
112-
for return_type in return_type_captures['return_type']:
111+
captures = self._captures("(function_declaration (user_type (identifier) @return_type))", entity.node)
112+
if 'return_type' in captures:
113+
for return_type in captures['return_type']:
113114
entity.add_symbol("return_type", return_type)
114115

115116
def is_dependency(self, file_path: str) -> bool:
@@ -134,7 +135,7 @@ def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
134135
if node.type == 'call_expression':
135136
# Find the identifier being called
136137
for child in node.children:
137-
if child.type in ['simple_identifier', 'navigation_expression']:
138+
if child.type in ['identifier', 'navigation_expression']:
138139
for file, resolved_node in self.resolve(files, lsp, file_path, path, child):
139140
method_dec = self.find_parent(resolved_node, ['function_declaration', 'class_declaration', 'object_declaration'])
140141
if method_dec and method_dec.type in ['class_declaration', 'object_declaration']:

api/analyzers/source_analyzer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,20 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
168168
logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}')
169169
for _, entity in file.entities.items():
170170
entity.resolved_symbol(lambda key, symbol, fp=file_path: analyzers[fp.suffix].resolve_symbol(self.files, lsps[fp.suffix], fp, path, key, symbol))
171-
for key, symbols in entity.symbols.items():
172-
for symbol in symbols:
171+
for key, resolved_set in entity.resolved_symbols.items():
172+
for resolved in resolved_set:
173173
if key == "base_class":
174-
graph.connect_entities("EXTENDS", entity.id, symbol.id)
174+
graph.connect_entities("EXTENDS", entity.id, resolved.id)
175175
elif key == "implement_interface":
176-
graph.connect_entities("IMPLEMENTS", entity.id, symbol.id)
176+
graph.connect_entities("IMPLEMENTS", entity.id, resolved.id)
177177
elif key == "extend_interface":
178-
graph.connect_entities("EXTENDS", entity.id, symbol.id)
178+
graph.connect_entities("EXTENDS", entity.id, resolved.id)
179179
elif key == "call":
180-
graph.connect_entities("CALLS", entity.id, symbol.id)
180+
graph.connect_entities("CALLS", entity.id, resolved.id)
181181
elif key == "return_type":
182-
graph.connect_entities("RETURNS", entity.id, symbol.id)
182+
graph.connect_entities("RETURNS", entity.id, resolved.id)
183183
elif key == "parameters":
184-
graph.connect_entities("PARAMETERS", entity.id, symbol.id)
184+
graph.connect_entities("PARAMETERS", entity.id, resolved.id)
185185

186186
def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None:
187187
self.first_pass(path, files, [], graph)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/**
2+
* A base interface for logging
3+
*/
4+
interface Logger {
5+
fun log(message: String)
6+
}
7+
8+
/**
9+
* Base class for shapes
10+
*/
11+
open class Shape(val name: String) {
12+
open fun area(): Double = 0.0
13+
}
14+
15+
class Circle(val radius: Double) : Shape("circle"), Logger {
16+
override fun area(): Double {
17+
return Math.PI * radius * radius
18+
}
19+
20+
override fun log(message: String) {
21+
println(message)
22+
}
23+
}
24+
25+
fun calculateTotal(shapes: List<Shape>): Double {
26+
var total = 0.0
27+
for (shape in shapes) {
28+
total += shape.area()
29+
}
30+
return total
31+
}
32+
33+
object AppConfig : Logger {
34+
val version = "1.0"
35+
36+
override fun log(message: String) {
37+
println("[$version] $message")
38+
}
39+
}

tests/test_kotlin_analyzer.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Tests for the Kotlin analyzer - extraction only (no DB required)."""
2+
3+
import unittest
4+
from pathlib import Path
5+
6+
from api.analyzers.kotlin.analyzer import KotlinAnalyzer
7+
from api.entities.entity import Entity
8+
from api.entities.file import File
9+
10+
11+
def _entity_name(analyzer, entity):
12+
"""Get the name of an entity using the analyzer."""
13+
return analyzer.get_entity_name(entity.node)
14+
15+
16+
class TestKotlinAnalyzer(unittest.TestCase):
17+
@classmethod
18+
def setUpClass(cls):
19+
cls.analyzer = KotlinAnalyzer()
20+
source_dir = Path(__file__).parent / "source_files" / "kotlin"
21+
cls.sample_path = source_dir / "sample.kt"
22+
source = cls.sample_path.read_bytes()
23+
tree = cls.analyzer.parser.parse(source)
24+
cls.file = File(cls.sample_path, tree)
25+
26+
# Walk AST and extract entities
27+
types = cls.analyzer.get_entity_types()
28+
stack = [tree.root_node]
29+
while stack:
30+
node = stack.pop()
31+
if node.type in types:
32+
entity = Entity(node)
33+
cls.analyzer.add_symbols(entity)
34+
cls.file.add_entity(entity)
35+
stack.extend(node.children)
36+
else:
37+
stack.extend(node.children)
38+
39+
def _entity_names(self):
40+
return [_entity_name(self.analyzer, e) for e in self.file.entities.values()]
41+
42+
def test_entity_types(self):
43+
"""Analyzer should recognise Kotlin entity types."""
44+
self.assertEqual(
45+
self.analyzer.get_entity_types(),
46+
['class_declaration', 'object_declaration', 'function_declaration'],
47+
)
48+
49+
def test_class_extraction(self):
50+
"""Classes should be extracted."""
51+
names = self._entity_names()
52+
self.assertIn("Shape", names)
53+
self.assertIn("Circle", names)
54+
55+
def test_interface_extraction(self):
56+
"""Interfaces should be extracted."""
57+
names = self._entity_names()
58+
self.assertIn("Logger", names)
59+
60+
def test_object_extraction(self):
61+
"""Object declarations should be extracted."""
62+
names = self._entity_names()
63+
self.assertIn("AppConfig", names)
64+
65+
def test_function_extraction(self):
66+
"""Top-level functions should be extracted."""
67+
names = self._entity_names()
68+
self.assertIn("calculateTotal", names)
69+
70+
def test_class_label(self):
71+
"""Classes should get the 'Class' label."""
72+
for entity in self.file.entities.values():
73+
if _entity_name(self.analyzer, entity) in ("Shape", "Circle"):
74+
self.assertEqual(self.analyzer.get_entity_label(entity.node), "Class")
75+
76+
def test_interface_label(self):
77+
"""Interfaces should get the 'Interface' label."""
78+
for entity in self.file.entities.values():
79+
if _entity_name(self.analyzer, entity) == "Logger":
80+
self.assertEqual(self.analyzer.get_entity_label(entity.node), "Interface")
81+
82+
def test_object_label(self):
83+
"""Object declarations should get the 'Object' label."""
84+
for entity in self.file.entities.values():
85+
if _entity_name(self.analyzer, entity) == "AppConfig":
86+
self.assertEqual(self.analyzer.get_entity_label(entity.node), "Object")
87+
88+
def test_base_class_symbol(self):
89+
"""Circle should have Shape as base_class (first delegation specifier)."""
90+
for entity in self.file.entities.values():
91+
if _entity_name(self.analyzer, entity) == "Circle":
92+
base_names = [
93+
s.text.decode("utf-8")
94+
for s in entity.symbols.get("base_class", [])
95+
]
96+
self.assertIn("Shape", base_names)
97+
98+
def test_interface_implementation(self):
99+
"""Circle should implement Logger (second delegation specifier)."""
100+
for entity in self.file.entities.values():
101+
if _entity_name(self.analyzer, entity) == "Circle":
102+
iface_names = [
103+
s.text.decode("utf-8")
104+
for s in entity.symbols.get("implement_interface", [])
105+
]
106+
self.assertIn("Logger", iface_names)
107+
108+
def test_is_dependency(self):
109+
"""Build/gradle paths should be flagged as dependencies."""
110+
self.assertTrue(self.analyzer.is_dependency("project/build/classes/Main.kt"))
111+
self.assertTrue(self.analyzer.is_dependency("project/.gradle/cache/lib.kt"))
112+
self.assertFalse(self.analyzer.is_dependency("src/main/kotlin/App.kt"))
113+
114+
115+
if __name__ == "__main__":
116+
unittest.main()

uv.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)