1- import json
1+ from __future__ import annotations
2+
23import pstats
34import sqlite3
45from copy import copy
56from pathlib import Path
7+ from typing import Any , TextIO
68
79from codeflash .cli_cmds .console import logger
810
911
1012class ProfileStats (pstats .Stats ):
13+ # Attributes set by pstats.Stats.init() — stubs don't expose them
14+ files : list [str ]
15+ stream : TextIO
16+ top_level : set [tuple [str , int , str ]]
17+ total_calls : int
18+ prim_calls : int
19+ total_tt : float
20+ max_name_len : int
21+ fcn_list : list [tuple [str , int , str ]] | None
22+ sort_arg_dict : dict [str , tuple [Any , ...]]
23+ all_callees : dict [tuple [str , int , str ], dict [tuple [str , int , str ], tuple [int , int , float , float ]]] | None
24+ stats : dict [tuple [str , int , str ], tuple [int , int , int | float , int | float , dict [Any , Any ]]]
25+
1126 def __init__ (self , trace_file_path : str , time_unit : str = "ns" ) -> None :
1227 assert Path (trace_file_path ).is_file (), f"Trace file { trace_file_path } does not exist"
1328 assert time_unit in {"ns" , "us" , "ms" , "s" }, f"Invalid time unit { time_unit } "
1429 self .trace_file_path = trace_file_path
1530 self .time_unit = time_unit
1631 logger .debug (hasattr (self , "create_stats" ))
17- super ().__init__ (copy (self ))
32+ super ().__init__ (copy (self )) # type: ignore[arg-type] # pstats uses duck-typed create_stats interface
1833
1934 def create_stats (self ) -> None :
2035 self .con = sqlite3 .connect (self .trace_file_path )
2136 cur = self .con .cursor ()
22- pdata = cur .execute ("SELECT * FROM pstats" ).fetchall ()
37+ pdata = cur .execute (
38+ "SELECT filename, line_number, function, class_name,"
39+ " call_count_nonrecursive, num_callers, total_time_ns, cumulative_time_ns"
40+ " FROM pstats"
41+ ).fetchall ()
2342 self .con .close ()
2443 time_conversion_factor = {"ns" : 1 , "us" : 1e3 , "ms" : 1e6 , "s" : 1e9 }[self .time_unit ]
2544 self .stats = {}
@@ -32,31 +51,18 @@ def create_stats(self) -> None:
3251 num_callers ,
3352 total_time_ns ,
3453 cumulative_time_ns ,
35- callers ,
3654 ) in pdata :
37- loaded_callers = json .loads (callers )
38- unmapped_callers = {}
39- for caller in loaded_callers :
40- caller_key = caller ["key" ]
41- if isinstance (caller_key , list ):
42- caller_key = tuple (caller_key )
43- elif not isinstance (caller_key , tuple ):
44- caller_key = (caller_key ,) if not isinstance (caller_key , (list , tuple )) else tuple (caller_key )
45- unmapped_callers [caller_key ] = caller ["value" ]
46-
47- # Create function key with class name if present (matching tracer.py format)
4855 function_name = f"{ class_name } .{ function } " if class_name else function
4956
5057 self .stats [(filename , line_number , function_name )] = (
5158 call_count_nonrecursive ,
5259 num_callers ,
5360 total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns ,
5461 cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns ,
55- unmapped_callers ,
62+ {} ,
5663 )
5764
58- def print_stats (self , * amount ) -> pstats .Stats : # noqa: ANN002
59- # Copied from pstats.Stats.print_stats and modified to print the correct time unit
65+ def print_stats (self , * amount : str | float ) -> ProfileStats :
6066 for filename in self .files :
6167 print (filename , file = self .stream )
6268 if self .files :
@@ -74,8 +80,8 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002
7480 _width , list_ = self .get_print_list (amount )
7581 if list_ :
7682 self .print_title ()
77- for func in list_ :
78- self .print_line (func )
83+ for fn in list_ :
84+ self .print_line (fn )
7985 print (file = self .stream )
8086 print (file = self .stream )
8187 return self
0 commit comments