Skip to content

Commit 22599e2

Browse files
committed
fix: resolve mypy strict errors in models.py
1 parent 589a874 commit 22599e2

1 file changed

Lines changed: 21 additions & 18 deletions

File tree

codeflash/models/models.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from functools import lru_cache
1010
from pathlib import Path
1111
from re import Pattern
12-
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast
12+
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
1313

1414
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
1515
from pydantic.dataclasses import dataclass
1616

1717
from codeflash.models.test_type import TestType
1818

1919
if TYPE_CHECKING:
20-
from collections.abc import Iterator
20+
from collections.abc import Generator
2121

2222
import libcst as cst
2323
from rich.tree import Tree
@@ -298,11 +298,13 @@ def flat(self) -> str:
298298
299299
"""
300300
if self._cache.get("flat") is not None:
301-
return self._cache["flat"]
302-
self._cache["flat"] = "\n".join(
301+
result: str = self._cache["flat"]
302+
return result
303+
flat: str = "\n".join(
303304
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
304305
)
305-
return self._cache["flat"]
306+
self._cache["flat"] = flat
307+
return flat
306308

307309
@property
308310
def markdown(self) -> str:
@@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]:
332334
333335
"""
334336
try:
335-
return self._cache["file_to_path"]
337+
cached: dict[str, str] = self._cache["file_to_path"]
338+
return cached
336339
except KeyError:
337340
mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
338341
self._cache["file_to_path"] = mapping
@@ -497,8 +500,8 @@ def _normalize_path_for_comparison(path: Path) -> str:
497500
# Only lowercase on Windows where filesystem is case-insensitive
498501
return resolved.lower() if sys.platform == "win32" else resolved
499502

500-
def __iter__(self) -> Iterator[TestFile]:
501-
return iter(self.test_files)
503+
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
504+
yield from self.test_files
502505

503506
def __len__(self) -> int:
504507
return len(self.test_files)
@@ -517,9 +520,9 @@ class CandidateEvaluationContext:
517520
optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
518521
is_correct: dict[str, bool] = Field(default_factory=dict)
519522
optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
520-
ast_code_to_id: dict = Field(default_factory=dict)
523+
ast_code_to_id: dict[str, Any] = Field(default_factory=dict)
521524
optimizations_post: dict[str, str] = Field(default_factory=dict)
522-
valid_optimizations: list = Field(default_factory=list)
525+
valid_optimizations: list[Any] = Field(default_factory=list)
523526

524527
def record_failed_candidate(self, optimization_id: str) -> None:
525528
"""Record results for a failed candidate."""
@@ -546,7 +549,7 @@ def handle_duplicate_candidate(
546549
# Copy results from the previous evaluation (use .get() in case past_opt_id was registered
547550
# but never benchmarked due to an unhandled exception in process_single_candidate)
548551
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id)
549-
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id)
552+
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id, False)
550553
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)
551554

552555
# Line profiler results only available for successful runs
@@ -634,7 +637,7 @@ class OriginalCodeBaseline(BaseModel):
634637
behavior_test_results: TestResults
635638
benchmarking_test_results: TestResults
636639
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
637-
line_profile_results: dict
640+
line_profile_results: dict[str, Any]
638641
runtime: int
639642
coverage_results: Optional[CoverageData]
640643
async_throughput: Optional[int] = None
@@ -796,7 +799,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
796799
f"// Testing function: {self.function_getting_tested}"
797800
)
798801

799-
if self.test_class_name:
802+
if self.test_class_name and self.test_function_name:
800803
for stmt in module_node.body:
801804
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
802805
func_node = self.find_func_in_class(stmt, self.test_function_name)
@@ -887,7 +890,7 @@ def group_by_benchmarks(
887890
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
888891
from codeflash.code_utils.code_utils import module_name_from_file_path
889892

890-
test_results_by_benchmark = defaultdict(TestResults)
893+
test_results_by_benchmark: defaultdict[BenchmarkKey, TestResults] = defaultdict(TestResults)
891894
benchmark_module_path = {}
892895
for benchmark_key in benchmark_keys:
893896
benchmark_module_path[benchmark_key] = module_name_from_file_path(
@@ -1018,7 +1021,7 @@ def effective_loop_count(self) -> int:
10181021
return max(loop_indices) if loop_indices else 0
10191022

10201023
def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
1021-
map_gen_test_file_to_no_of_tests = Counter()
1024+
map_gen_test_file_to_no_of_tests: Counter[Path] = Counter()
10221025
for gen_test_result in self.test_results:
10231026
if (
10241027
gen_test_result.test_type == TestType.GENERATED_REGRESSION
@@ -1027,8 +1030,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
10271030
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
10281031
return map_gen_test_file_to_no_of_tests
10291032

1030-
def __iter__(self) -> Iterator[FunctionTestInvocation]:
1031-
return iter(self.test_results)
1033+
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
1034+
yield from self.test_results
10321035

10331036
def __len__(self) -> int:
10341037
return len(self.test_results)
@@ -1054,7 +1057,7 @@ def __eq__(self, other: object) -> bool:
10541057
if len(self) != len(other):
10551058
return False
10561059
original_recursion_limit = sys.getrecursionlimit()
1057-
cast("TestResults", other)
1060+
assert isinstance(other, TestResults)
10581061
for test_result in self:
10591062
other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
10601063
if other_test_result is None:

0 commit comments

Comments
 (0)