Skip to content

Commit a5aa75d

Browse files
fix typing issues
1 parent ea51f78 commit a5aa75d

5 files changed

Lines changed: 30 additions & 30 deletions

File tree

codeflash/languages/golang/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import logging
44
from typing import TYPE_CHECKING
55

6-
from codeflash.languages.base import CodeContext, HelperFunction, Language
6+
from codeflash.languages.base import CodeContext, HelperFunction
77
from codeflash.languages.golang.parser import GoAnalyzer
8+
from codeflash.languages.language_enum import Language
89

910
if TYPE_CHECKING:
1011
from pathlib import Path

codeflash/languages/golang/function_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
class GoFunctionOptimizer(FunctionOptimizer):
2828
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
2929
from codeflash.languages import get_language_support
30-
from codeflash.languages.base import Language
30+
from codeflash.languages.language_enum import Language
3131

3232
language = Language(self.function_to_optimize.language)
3333
lang_support = get_language_support(language)

codeflash/languages/golang/parse.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import logging
55
import re
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, Any
77

88
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
99

@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313

1414
from codeflash.models.models import TestFiles
15+
from codeflash.models.test_type import TestType
1516
from codeflash.verification.verification_utils import TestConfig
1617

1718
logger = logging.getLogger(__name__)
@@ -29,7 +30,7 @@ def parse_go_test_output(
2930
test_json_path: Path,
3031
test_files: TestFiles,
3132
test_config: TestConfig,
32-
run_result: subprocess.CompletedProcess | None = None,
33+
run_result: subprocess.CompletedProcess[str] | None = None,
3334
) -> TestResults:
3435
test_results = TestResults()
3536

@@ -71,10 +72,11 @@ def parse_go_test_output(
7172
active[test_name] = _TestIteration(test_name=test_name, package=package)
7273
continue
7374

74-
it = active.get(test_name)
75-
if it is None:
76-
it = _TestIteration(test_name=test_name, package=package)
77-
active[test_name] = it
75+
maybe_it = active.get(test_name)
76+
if maybe_it is None:
77+
maybe_it = _TestIteration(test_name=test_name, package=package)
78+
active[test_name] = maybe_it
79+
it = maybe_it
7880

7981
if action == "output":
8082
output_text = event.get("Output", "")
@@ -109,9 +111,6 @@ def parse_go_test_output(
109111

110112
test_file_path = _resolve_test_file(it.test_name, it.package, test_files, base_dir)
111113
test_type = _resolve_test_type(test_file_path, test_files)
112-
if test_type is None:
113-
logger.debug("Skipping test %s: could not resolve test type", it.test_name)
114-
continue
115114

116115
test_results.add(
117116
FunctionTestInvocation(
@@ -157,23 +156,20 @@ def __init__(self, test_name: str, package: str) -> None:
157156
self.stdout: str = ""
158157

159158

160-
def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None) -> str:
159+
def _read_json_output(path: Path, run_result: subprocess.CompletedProcess[str] | None) -> str:
161160
try:
162161
content = path.read_text(encoding="utf-8")
163162
if content.strip():
164163
return content
165164
except Exception:
166165
pass
167166
if run_result is not None:
168-
stdout = run_result.stdout
169-
if isinstance(stdout, bytes):
170-
stdout = stdout.decode("utf-8", errors="replace")
171-
return stdout or ""
167+
return run_result.stdout or ""
172168
return ""
173169

174170

175-
def _parse_json_lines(content: str) -> list[dict]:
176-
events: list[dict] = []
171+
def _parse_json_lines(content: str) -> list[dict[str, Any]]:
172+
events: list[dict[str, Any]] = []
177173
for line in content.splitlines():
178174
line = line.strip()
179175
if not line:
@@ -199,7 +195,7 @@ def _resolve_test_file(test_name: str, package: str, test_files: TestFiles, base
199195
return base_dir / f"{test_name}.go"
200196

201197

202-
def _resolve_test_type(test_file_path: Path, test_files: TestFiles):
198+
def _resolve_test_type(test_file_path: Path, test_files: TestFiles) -> TestType:
203199
from codeflash.models.test_type import TestType
204200

205201
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)

codeflash/languages/golang/replacement.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from codeflash.languages.golang.parser import GoAnalyzer
77

88
if TYPE_CHECKING:
9+
import tree_sitter
10+
911
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
12+
from codeflash.languages.golang.parser import GoGlobalDeclaration
1013

1114
logger = logging.getLogger(__name__)
1215

@@ -102,7 +105,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer:
102105
return original_source
103106

104107
orig_decls = analyzer.find_global_declarations(original_source)
105-
orig_names_to_decl: dict[str, object] = {}
108+
orig_names_to_decl: dict[str, GoGlobalDeclaration] = {}
106109
for decl in orig_decls:
107110
for name in decl.names:
108111
orig_names_to_decl[name] = decl
@@ -131,7 +134,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer:
131134
return original_source
132135

133136

134-
def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str:
137+
def _replace_declaration_block(source: str, orig_decl: GoGlobalDeclaration, new_source_code: str) -> str:
135138
lines = source.splitlines(keepends=True)
136139
start = orig_decl.starting_line - 1
137140
end = orig_decl.ending_line
@@ -186,19 +189,19 @@ def remove_test_functions(test_source: str, functions_to_remove: list[str], anal
186189
return "".join(lines)
187190

188191

189-
def _find_doc_comment_start(node: object) -> int | None:
190-
prev = getattr(node, "prev_named_sibling", None)
192+
def _find_doc_comment_start(node: tree_sitter.Node) -> int | None:
193+
prev = node.prev_named_sibling
191194
if prev is None:
192195
return None
193-
if getattr(prev, "type", None) != "comment":
196+
if prev.type != "comment":
194197
return None
195198
if prev.end_point.row + 1 != node.start_point.row:
196199
return None
197-
comment_start = prev.start_point.row + 1
200+
comment_start: int = prev.start_point.row + 1
198201
current = prev
199202
while True:
200-
earlier = getattr(current, "prev_named_sibling", None)
201-
if earlier is None or getattr(earlier, "type", None) != "comment":
203+
earlier = current.prev_named_sibling
204+
if earlier is None or earlier.type != "comment":
202205
break
203206
if earlier.end_point.row + 1 != current.start_point.row:
204207
break

codeflash/languages/golang/support.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING, Any
66

7+
from codeflash.languages.base import LanguageSupport
78
from codeflash.languages.golang.comparator import compare_test_results as _compare_results
89
from codeflash.languages.golang.config import detect_go_project, detect_go_version
910
from codeflash.languages.golang.context import extract_code_context as _extract_context
@@ -29,18 +30,17 @@
2930
DependencyResolver,
3031
FunctionFilterCriteria,
3132
HelperFunction,
32-
InvocationId,
3333
ReferenceInfo,
3434
TestInfo,
3535
)
3636
from codeflash.models.function_types import FunctionToOptimize
37-
from codeflash.models.models import GeneratedTestsList
37+
from codeflash.models.models import GeneratedTestsList, InvocationId
3838

3939
logger = logging.getLogger(__name__)
4040

4141

4242
@register_language
43-
class GoSupport:
43+
class GoSupport(LanguageSupport):
4444
def __init__(self) -> None:
4545
self._analyzer = GoAnalyzer()
4646
self._go_version: str | None = None

0 commit comments

Comments
 (0)