Skip to content

Commit 436d642

Browse files
committed
perf: defer libcst, Rich, comparator imports in models.py
Move libcst, rich.tree.Tree, console, comparator, code_utils, registry, lsp.helpers, and LspMarkdownMessage from module-level to the methods that use them. Only pydantic and TestType remain at module level (needed for class definitions). models.py import: 633ms → 125ms on Azure Standard_D4s_v5.
1 parent 88babfe commit 436d642

1 file changed

Lines changed: 39 additions & 22 deletions

File tree

codeflash/models/models.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,23 @@
11
from __future__ import annotations
22

3-
from collections import Counter, defaultdict
4-
from functools import lru_cache
5-
from typing import TYPE_CHECKING
6-
7-
import libcst as cst
8-
from rich.tree import Tree
9-
10-
from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
11-
from codeflash.languages.registry import get_language_support
12-
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
13-
from codeflash.lsp.lsp_message import LspMarkdownMessage
14-
from codeflash.models.test_type import TestType
15-
16-
if TYPE_CHECKING:
17-
from collections.abc import Iterator
18-
193
import enum
204
import re
215
import sys
6+
from collections import Counter, defaultdict
227
from collections.abc import Collection
238
from enum import Enum, IntEnum
9+
from functools import lru_cache
2410
from pathlib import Path
2511
from re import Pattern
26-
from typing import Any, NamedTuple, Optional, cast
12+
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast
2713

2814
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
2915
from pydantic.dataclasses import dataclass
3016

31-
from codeflash.cli_cmds.console import console, logger
32-
from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code
33-
from codeflash.code_utils.env_utils import is_end_to_end
34-
from codeflash.verification.comparator import comparator
17+
from codeflash.models.test_type import TestType
18+
19+
if TYPE_CHECKING:
20+
from collections.abc import Iterator
3521

3622

3723
@dataclass(frozen=True)
@@ -254,6 +240,8 @@ class CodeString(BaseModel):
254240
def validate_code_syntax(self) -> CodeString:
255241
"""Validate code syntax for the specified language."""
256242
if self.language == "python":
243+
from codeflash.code_utils.code_utils import validate_python_code # noqa: PLC0415
244+
257245
validate_python_code(self.code)
258246
else:
259247
from codeflash.languages.registry import get_language_support
@@ -267,6 +255,8 @@ def validate_code_syntax(self) -> CodeString:
267255

268256
def get_comment_prefix(file_path: Path) -> str:
269257
"""Get the comment prefix for a given language."""
258+
from codeflash.languages.registry import get_language_support # noqa: PLC0415
259+
270260
support = get_language_support(file_path)
271261
return support.comment_prefix
272262

@@ -565,6 +555,8 @@ def handle_duplicate_candidate(
565555
self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown
566556

567557
# Update to shorter code if this candidate has a shorter diff
558+
from codeflash.code_utils.code_utils import diff_length # noqa: PLC0415
559+
568560
new_diff_len = diff_length(candidate.source_code.flat, original_flat_code)
569561
if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
570562
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
@@ -574,6 +566,8 @@ def register_new_candidate(
574566
self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str
575567
) -> None:
576568
"""Register a new candidate that hasn't been seen before."""
569+
from codeflash.code_utils.code_utils import diff_length # noqa: PLC0415
570+
577571
self.ast_code_to_id[normalized_code] = {
578572
"optimization_id": candidate.optimization_id,
579573
"shorter_source_code": candidate.source_code,
@@ -668,7 +662,10 @@ def build_message(self) -> str:
668662
return f"{self.coverage:.1f}%"
669663

670664
def log_coverage(self) -> None:
671-
from rich.tree import Tree
665+
from rich.tree import Tree # noqa: PLC0415
666+
667+
from codeflash.cli_cmds.console import console, logger # noqa: PLC0415
668+
from codeflash.code_utils.env_utils import is_end_to_end # noqa: PLC0415
672669

673670
tree = Tree("Test Coverage Results")
674671
tree.add(f"Main Function: {self.main_func_coverage.name}: {self.coverage:.2f}%")
@@ -769,12 +766,16 @@ def test_fn_qualified_name(self) -> str:
769766
)
770767

771768
def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
769+
import libcst as cst # noqa: PLC0415
770+
772771
for stmt in class_node.body.body:
773772
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name:
774773
return stmt
775774
return None
776775

777776
def get_src_code(self, test_path: Path) -> Optional[str]:
777+
import libcst as cst # noqa: PLC0415
778+
778779
if not test_path.exists():
779780
return None
780781
try:
@@ -856,6 +857,8 @@ def add(self, function_test_invocation: FunctionTestInvocation) -> None:
856857
unique_id = function_test_invocation.unique_invocation_loop_id
857858
test_result_idx = self.test_result_idx
858859
if unique_id in test_result_idx:
860+
from codeflash.cli_cmds.console import DEBUG_MODE, logger # noqa: PLC0415
861+
859862
if DEBUG_MODE:
860863
logger.warning(f"Test result with id {unique_id} already exists. SKIPPING")
861864
return
@@ -876,6 +879,8 @@ def group_by_benchmarks(
876879
self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path
877880
) -> dict[BenchmarkKey, TestResults]:
878881
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
882+
from codeflash.code_utils.code_utils import module_name_from_file_path # noqa: PLC0415
883+
879884
test_results_by_benchmark = defaultdict(TestResults)
880885
benchmark_module_path = {}
881886
for benchmark_key in benchmark_keys:
@@ -929,9 +934,17 @@ def report_to_string(report: dict[TestType, dict[str, int]]) -> str:
929934

930935
@staticmethod
931936
def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
937+
from rich.tree import Tree # noqa: PLC0415
938+
939+
from codeflash.lsp.helpers import is_LSP_enabled # noqa: PLC0415
940+
932941
tree = Tree(title)
933942

934943
if is_LSP_enabled():
944+
from codeflash.cli_cmds.console import lsp_log # noqa: PLC0415
945+
from codeflash.lsp.helpers import report_to_markdown_table # noqa: PLC0415
946+
from codeflash.lsp.lsp_message import LspMarkdownMessage # noqa: PLC0415
947+
935948
# Build markdown table
936949
markdown = report_to_markdown_table(report, title)
937950
lsp_log(LspMarkdownMessage(markdown=markdown))
@@ -946,6 +959,8 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
946959
return tree
947960

948961
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
962+
from codeflash.cli_cmds.console import logger # noqa: PLC0415
963+
949964
# Efficient single traversal, directly accumulating into a dict.
950965
# can track mins here and only sums can be return in total_passed_runtime
951966
by_id: dict[InvocationId, list[int]] = {}
@@ -1025,6 +1040,8 @@ def __bool__(self) -> bool:
10251040
return bool(self.test_results)
10261041

10271042
def __eq__(self, other: object) -> bool:
1043+
from codeflash.verification.comparator import comparator # noqa: PLC0415
1044+
10281045
# Unordered comparison
10291046
if type(self) is not type(other):
10301047
return False

0 commit comments

Comments
 (0)