Skip to content

Commit 1df9a1a

Browse files
committed
refactor: extract shared types to break circular import
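Move OptimizedCandidateSource (from models.py) and FunctionParent (from function_types.py) into the new shared_types.py to avoid a circular dependency between models.py and function_optimizer.py; function_types.py now imports FunctionParent from shared_types.py and re-exports it via __all__. Also add the AIServiceBatchRefinerCandidate and AIServiceBatchRefinerRequest models to models.py and tighten several type annotations there.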
1 parent 0a2ec48 commit 1df9a1a

3 files changed

Lines changed: 85 additions & 28 deletions

File tree

codeflash/models/function_types.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,9 @@
1212
from pydantic import Field
1313
from pydantic.dataclasses import dataclass
1414

15+
from codeflash.models.shared_types import FunctionParent
1516

16-
@dataclass(frozen=True)
17-
class FunctionParent:
18-
name: str
19-
type: str
20-
21-
def __str__(self) -> str:
22-
return f"{self.type}:{self.name}"
17+
__all__ = ["FunctionParent", "FunctionToOptimize"]
2318

2419

2520
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})

codeflash/models/models.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
1515
from pydantic.dataclasses import dataclass
1616

17+
from codeflash.models.shared_types import OptimizedCandidateSource
1718
from codeflash.models.test_type import TestType
1819

1920
if TYPE_CHECKING:
@@ -50,6 +51,23 @@ class AIServiceRefinerRequest:
5051
additional_context_files: dict[str, str] | None = None # {filepath: content} for imported modules
5152

5253

54+
@dataclass(frozen=True)
55+
class AIServiceBatchRefinerCandidate:
56+
optimization_id: str
57+
optimized_source_code: str
58+
optimized_explanation: str
59+
optimized_code_runtime: int
60+
original_code_runtime: int
61+
speedup: str
62+
optimized_line_profiler_results: str
63+
64+
65+
@dataclass(frozen=True)
66+
class AIServiceBatchRefinerRequest:
67+
shared_context: dict[str, Any]
68+
candidates: list[dict[str, Any]]
69+
70+
5371
# this should be possible to auto serialize
5472
@dataclass(frozen=True)
5573
class AdaptiveOptimizedCandidate:
@@ -298,11 +316,11 @@ def flat(self) -> str:
298316
299317
"""
300318
if self._cache.get("flat") is not None:
301-
return self._cache["flat"]
319+
return cast("str", self._cache["flat"])
302320
self._cache["flat"] = "\n".join(
303321
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
304322
)
305-
return self._cache["flat"]
323+
return cast("str", self._cache["flat"])
306324

307325
@property
308326
def markdown(self) -> str:
@@ -332,7 +350,7 @@ def file_to_path(self) -> dict[str, str]:
332350
333351
"""
334352
try:
335-
return self._cache["file_to_path"]
353+
return cast("dict[str, str]", self._cache["file_to_path"])
336354
except KeyError:
337355
mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
338356
self._cache["file_to_path"] = mapping
@@ -494,7 +512,7 @@ def _normalize_path_for_comparison(path: Path) -> str:
494512
# Only lowercase on Windows where filesystem is case-insensitive
495513
return resolved.lower() if sys.platform == "win32" else resolved
496514

497-
def __iter__(self) -> Iterator[TestFile]:
515+
def __iter__(self) -> Iterator[TestFile]: # type: ignore[override]
498516
return iter(self.test_files)
499517

500518
def __len__(self) -> int:
@@ -514,9 +532,9 @@ class CandidateEvaluationContext:
514532
optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
515533
is_correct: dict[str, bool] = Field(default_factory=dict)
516534
optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
517-
ast_code_to_id: dict = Field(default_factory=dict)
535+
ast_code_to_id: dict[str, Any] = Field(default_factory=dict)
518536
optimizations_post: dict[str, str] = Field(default_factory=dict)
519-
valid_optimizations: list = Field(default_factory=list)
537+
valid_optimizations: list[Any] = Field(default_factory=list)
520538

521539
def record_failed_candidate(self, optimization_id: str) -> None:
522540
"""Record results for a failed candidate."""
@@ -543,7 +561,7 @@ def handle_duplicate_candidate(
543561
# Copy results from the previous evaluation (use .get() in case past_opt_id was registered
544562
# but never benchmarked due to an unhandled exception in process_single_candidate)
545563
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id)
546-
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id)
564+
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id, False)
547565
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)
548566

549567
# Line profiler results only available for successful runs
@@ -592,15 +610,6 @@ class TestsInFile:
592610
test_type: TestType
593611

594612

595-
class OptimizedCandidateSource(str, Enum):
596-
OPTIMIZE = "OPTIMIZE"
597-
OPTIMIZE_LP = "OPTIMIZE_LP"
598-
REFINE = "REFINE"
599-
REPAIR = "REPAIR"
600-
ADAPTIVE = "ADAPTIVE"
601-
JIT_REWRITE = "JIT_REWRITE"
602-
603-
604613
@dataclass(frozen=True)
605614
class OptimizedCandidate:
606615
source_code: CodeStringsMarkdown
@@ -631,7 +640,7 @@ class OriginalCodeBaseline(BaseModel):
631640
behavior_test_results: TestResults
632641
benchmarking_test_results: TestResults
633642
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
634-
line_profile_results: dict
643+
line_profile_results: dict[str, Any]
635644
runtime: int
636645
coverage_results: Optional[CoverageData]
637646
async_throughput: Optional[int] = None
@@ -794,6 +803,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
794803
)
795804

796805
if self.test_class_name:
806+
assert self.test_function_name is not None
797807
for stmt in module_node.body:
798808
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
799809
func_node = self.find_func_in_class(stmt, self.test_function_name)
@@ -884,7 +894,7 @@ def group_by_benchmarks(
884894
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
885895
from codeflash.code_utils.code_utils import module_name_from_file_path
886896

887-
test_results_by_benchmark = defaultdict(TestResults)
897+
test_results_by_benchmark: defaultdict[BenchmarkKey, TestResults] = defaultdict(TestResults)
888898
benchmark_module_path = {}
889899
for benchmark_key in benchmark_keys:
890900
benchmark_module_path[benchmark_key] = module_name_from_file_path(
@@ -1015,7 +1025,7 @@ def effective_loop_count(self) -> int:
10151025
return max(loop_indices) if loop_indices else 0
10161026

10171027
def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
1018-
map_gen_test_file_to_no_of_tests = Counter()
1028+
map_gen_test_file_to_no_of_tests: Counter[Path] = Counter()
10191029
for gen_test_result in self.test_results:
10201030
if (
10211031
gen_test_result.test_type == TestType.GENERATED_REGRESSION
@@ -1024,7 +1034,7 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
10241034
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
10251035
return map_gen_test_file_to_no_of_tests
10261036

1027-
def __iter__(self) -> Iterator[FunctionTestInvocation]:
1037+
def __iter__(self) -> Iterator[FunctionTestInvocation]: # type: ignore[override]
10281038
return iter(self.test_results)
10291039

10301040
def __len__(self) -> int:
@@ -1051,7 +1061,7 @@ def __eq__(self, other: object) -> bool:
10511061
if len(self) != len(other):
10521062
return False
10531063
original_recursion_limit = sys.getrecursionlimit()
1054-
cast("TestResults", other)
1064+
assert isinstance(other, TestResults)
10551065
for test_result in self:
10561066
other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
10571067
if other_test_result is None:

codeflash/models/shared_types.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Shared types for cross-repo use between codeflash CLI and codeflash-internal server.
2+
3+
This module defines types that are duplicated or shared between the client (CLI)
4+
and the server. Centralizing them here allows both sides to import from a single
5+
source of truth.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from enum import Enum
11+
12+
from pydantic.dataclasses import dataclass
13+
14+
# --- Enums ---
15+
16+
17+
class OptimizedCandidateSource(str, Enum):
18+
OPTIMIZE = "OPTIMIZE"
19+
OPTIMIZE_LP = "OPTIMIZE_LP"
20+
REFINE = "REFINE"
21+
REPAIR = "REPAIR"
22+
ADAPTIVE = "ADAPTIVE"
23+
JIT_REWRITE = "JIT_REWRITE"
24+
25+
26+
# --- Models ---
27+
28+
29+
@dataclass(frozen=True)
30+
class FunctionParent:
31+
name: str
32+
type: str
33+
34+
def __str__(self) -> str:
35+
return f"{self.type}:{self.name}"
36+
37+
38+
# --- Constants: Language identifiers ---
39+
40+
LANGUAGE_PYTHON = "python"
41+
LANGUAGE_JAVASCRIPT = "javascript"
42+
LANGUAGE_TYPESCRIPT = "typescript"
43+
LANGUAGE_JAVA = "java"
44+
45+
SUPPORTED_LANGUAGES = frozenset({LANGUAGE_PYTHON, LANGUAGE_JAVASCRIPT, LANGUAGE_TYPESCRIPT, LANGUAGE_JAVA})
46+
47+
# --- Constants: Test type names ---
48+
49+
TEST_TYPE_EXISTING_UNIT = "existing_unit_test"
50+
TEST_TYPE_GENERATED_REGRESSION = "generated_regression"
51+
TEST_TYPE_REPLAY = "replay_test"
52+
TEST_TYPE_CONCOLIC_COVERAGE = "concolic_coverage_test"

0 commit comments

Comments
 (0)