Skip to content

Commit 7c12c6c

Browse files
committed
perf: skip unused callers JSON decode in ProfileStats
The callers BLOB column was fetched and JSON-decoded for every row during FunctionRanker initialization, but the decoded data was never used — FunctionRanker.load_function_stats() discards it as `_callers`. create_stats() now selects only the columns it needs and stores an empty dict in place of the callers mapping, which eliminates O(n) JSON parsing at ranking startup. Also fixes all pre-existing mypy errors in this file by declaring the pstats.Stats attributes that the type stubs don't expose.
1 parent 5356817 commit 7c12c6c

1 file changed

Lines changed: 26 additions & 20 deletions

File tree

codeflash/tracing/profile_stats.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,44 @@
1-
import json
1+
from __future__ import annotations
2+
23
import pstats
34
import sqlite3
45
from copy import copy
56
from pathlib import Path
7+
from typing import Any, TextIO
68

79
from codeflash.cli_cmds.console import logger
810

911

1012
class ProfileStats(pstats.Stats):
13+
# Attributes set by pstats.Stats.init() — stubs don't expose them
14+
files: list[str]
15+
stream: TextIO
16+
top_level: set[tuple[str, int, str]]
17+
total_calls: int
18+
prim_calls: int
19+
total_tt: float
20+
max_name_len: int
21+
fcn_list: list[tuple[str, int, str]] | None
22+
sort_arg_dict: dict[str, tuple[Any, ...]]
23+
all_callees: dict[tuple[str, int, str], dict[tuple[str, int, str], tuple[int, int, float, float]]] | None
24+
stats: dict[tuple[str, int, str], tuple[int, int, int | float, int | float, dict[Any, Any]]]
25+
1126
def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
1227
assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
1328
assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
1429
self.trace_file_path = trace_file_path
1530
self.time_unit = time_unit
1631
logger.debug(hasattr(self, "create_stats"))
17-
super().__init__(copy(self))
32+
super().__init__(copy(self)) # type: ignore[arg-type] # pstats uses duck-typed create_stats interface
1833

1934
def create_stats(self) -> None:
2035
self.con = sqlite3.connect(self.trace_file_path)
2136
cur = self.con.cursor()
22-
pdata = cur.execute("SELECT * FROM pstats").fetchall()
37+
pdata = cur.execute(
38+
"SELECT filename, line_number, function, class_name,"
39+
" call_count_nonrecursive, num_callers, total_time_ns, cumulative_time_ns"
40+
" FROM pstats"
41+
).fetchall()
2342
self.con.close()
2443
time_conversion_factor = {"ns": 1, "us": 1e3, "ms": 1e6, "s": 1e9}[self.time_unit]
2544
self.stats = {}
@@ -32,31 +51,18 @@ def create_stats(self) -> None:
3251
num_callers,
3352
total_time_ns,
3453
cumulative_time_ns,
35-
callers,
3654
) in pdata:
37-
loaded_callers = json.loads(callers)
38-
unmapped_callers = {}
39-
for caller in loaded_callers:
40-
caller_key = caller["key"]
41-
if isinstance(caller_key, list):
42-
caller_key = tuple(caller_key)
43-
elif not isinstance(caller_key, tuple):
44-
caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key)
45-
unmapped_callers[caller_key] = caller["value"]
46-
47-
# Create function key with class name if present (matching tracer.py format)
4855
function_name = f"{class_name}.{function}" if class_name else function
4956

5057
self.stats[(filename, line_number, function_name)] = (
5158
call_count_nonrecursive,
5259
num_callers,
5360
total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns,
5461
cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns,
55-
unmapped_callers,
62+
{},
5663
)
5764

58-
def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002
59-
# Copied from pstats.Stats.print_stats and modified to print the correct time unit
65+
def print_stats(self, *amount: str | float) -> ProfileStats:
6066
for filename in self.files:
6167
print(filename, file=self.stream)
6268
if self.files:
@@ -74,8 +80,8 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002
7480
_width, list_ = self.get_print_list(amount)
7581
if list_:
7682
self.print_title()
77-
for func in list_:
78-
self.print_line(func)
83+
for fn in list_:
84+
self.print_line(fn)
7985
print(file=self.stream)
8086
print(file=self.stream)
8187
return self

0 commit comments

Comments
 (0)