44from __future__ import annotations
55
66import ast
7+ from bisect import bisect_left , bisect_right
8+ from dataclasses import dataclass
79from pathlib import Path
810
911from .explain_contract import (
1820from .types import GroupItemsLike , GroupMapLike
1921
2022
23+ @dataclass (frozen = True , slots = True )
24+ class _StatementRecord :
25+ node : ast .stmt
26+ start_line : int
27+ end_line : int
28+ start_col : int
29+ end_col : int
30+ type_name : str
31+
32+
33+ _StatementIndex = tuple [tuple [_StatementRecord , ...], tuple [int , ...]]
34+
35+
2136def signature_parts (group_key : str ) -> list [str ]:
2237 return [part for part in group_key .split ("|" ) if part ]
2338
@@ -50,6 +65,53 @@ def parsed_file_tree(
5065 return tree
5166
5267
68+ def _build_statement_index (tree : ast .AST ) -> _StatementIndex :
69+ records = tuple (
70+ sorted (
71+ (
72+ _StatementRecord (
73+ node = node ,
74+ start_line = int (getattr (node , "lineno" , 0 )),
75+ end_line = int (getattr (node , "end_lineno" , 0 )),
76+ start_col = int (getattr (node , "col_offset" , 0 )),
77+ end_col = int (getattr (node , "end_col_offset" , 0 )),
78+ type_name = type (node ).__name__ ,
79+ )
80+ for node in ast .walk (tree )
81+ if isinstance (node , ast .stmt )
82+ ),
83+ key = lambda record : (
84+ record .start_line ,
85+ record .end_line ,
86+ record .start_col ,
87+ record .end_col ,
88+ record .type_name ,
89+ ),
90+ )
91+ )
92+ start_lines = tuple (record .start_line for record in records )
93+ return records , start_lines
94+
95+
96+ def parsed_statement_index (
97+ filepath : str ,
98+ * ,
99+ ast_cache : dict [str , ast .AST | None ],
100+ stmt_index_cache : dict [str , _StatementIndex | None ],
101+ ) -> _StatementIndex | None :
102+ if filepath in stmt_index_cache :
103+ return stmt_index_cache [filepath ]
104+
105+ tree = parsed_file_tree (filepath , ast_cache = ast_cache )
106+ if tree is None :
107+ stmt_index_cache [filepath ] = None
108+ return None
109+
110+ index = _build_statement_index (tree )
111+ stmt_index_cache [filepath ] = index
112+ return index
113+
114+
53115def is_assert_like_stmt (statement : ast .stmt ) -> bool :
54116 if isinstance (statement , ast .Assert ):
55117 return True
@@ -72,52 +134,50 @@ def assert_range_stats(
72134 start_line : int ,
73135 end_line : int ,
74136 ast_cache : dict [str , ast .AST | None ],
137+ stmt_index_cache : dict [str , _StatementIndex | None ],
75138 range_cache : dict [tuple [str , int , int ], tuple [int , int , int ]],
76139) -> tuple [int , int , int ]:
77140 cache_key = (filepath , start_line , end_line )
78141 if cache_key in range_cache :
79142 return range_cache [cache_key ]
80143
81- tree = parsed_file_tree (filepath , ast_cache = ast_cache )
82- if tree is None :
144+ statement_index = parsed_statement_index (
145+ filepath ,
146+ ast_cache = ast_cache ,
147+ stmt_index_cache = stmt_index_cache ,
148+ )
149+ if statement_index is None :
83150 range_cache [cache_key ] = (0 , 0 , 0 )
84151 return 0 , 0 , 0
85152
86- statements = [
87- node
88- for node in ast .walk (tree )
89- if isinstance (node , ast .stmt )
90- and int (getattr (node , "lineno" , 0 )) >= start_line
91- and int (getattr (node , "end_lineno" , 0 )) <= end_line
92- ]
93- if not statements :
153+ records , start_lines = statement_index
154+ if not records :
94155 range_cache [cache_key ] = (0 , 0 , 0 )
95156 return 0 , 0 , 0
96157
97- ordered_statements = sorted (
98- statements ,
99- key = lambda statement : (
100- int (getattr (statement , "lineno" , 0 )),
101- int (getattr (statement , "end_lineno" , 0 )),
102- int (getattr (statement , "col_offset" , 0 )),
103- int (getattr (statement , "end_col_offset" , 0 )),
104- type (statement ).__name__ ,
105- ),
106- )
158+ left = bisect_left (start_lines , start_line )
159+ right = bisect_right (start_lines , end_line )
160+ if left >= right :
161+ range_cache [cache_key ] = (0 , 0 , 0 )
162+ return 0 , 0 , 0
107163
108- total = len ( ordered_statements )
109- assert_like = 0
110- max_consecutive = 0
111- current_consecutive = 0
112- for statement in ordered_statements :
113- if is_assert_like_stmt (statement ):
164+ total , assert_like , max_consecutive , current_consecutive = ( 0 , 0 , 0 , 0 )
165+ for record in records [ left : right ]:
166+ if record . end_line > end_line :
167+ continue
168+ total += 1
169+ if is_assert_like_stmt (record . node ):
114170 assert_like += 1
115171 current_consecutive += 1
116172 if current_consecutive > max_consecutive :
117173 max_consecutive = current_consecutive
118174 else :
119175 current_consecutive = 0
120176
177+ if total == 0 :
178+ range_cache [cache_key ] = (0 , 0 , 0 )
179+ return 0 , 0 , 0
180+
121181 stats = (total , assert_like , max_consecutive )
122182 range_cache [cache_key ] = stats
123183 return stats
@@ -129,13 +189,15 @@ def is_assert_only_range(
129189 start_line : int ,
130190 end_line : int ,
131191 ast_cache : dict [str , ast .AST | None ],
192+ stmt_index_cache : dict [str , _StatementIndex | None ],
132193 range_cache : dict [tuple [str , int , int ], tuple [int , int , int ]],
133194) -> bool :
134195 total , assert_like , _ = assert_range_stats (
135196 filepath = filepath ,
136197 start_line = start_line ,
137198 end_line = end_line ,
138199 ast_cache = ast_cache ,
200+ stmt_index_cache = stmt_index_cache ,
139201 range_cache = range_cache ,
140202 )
141203 return total > 0 and total == assert_like
@@ -163,6 +225,7 @@ def enrich_with_assert_facts(
163225 facts : dict [str , str ],
164226 items : GroupItemsLike ,
165227 ast_cache : dict [str , ast .AST | None ],
228+ stmt_index_cache : dict [str , _StatementIndex | None ],
166229 range_cache : dict [tuple [str , int , int ], tuple [int , int , int ]],
167230) -> None :
168231 assert_only = True
@@ -187,6 +250,7 @@ def enrich_with_assert_facts(
187250 start_line = start_line ,
188251 end_line = end_line ,
189252 ast_cache = ast_cache ,
253+ stmt_index_cache = stmt_index_cache ,
190254 range_cache = range_cache ,
191255 )
192256 total_statements += range_total
@@ -205,6 +269,7 @@ def enrich_with_assert_facts(
205269 start_line = start_line ,
206270 end_line = end_line ,
207271 ast_cache = ast_cache ,
272+ stmt_index_cache = stmt_index_cache ,
208273 range_cache = range_cache ,
209274 )
210275 ):
@@ -230,6 +295,7 @@ def build_block_group_facts(block_groups: GroupMapLike) -> dict[str, dict[str, s
230295 Renderers (HTML/TXT/JSON) should only display these facts.
231296 """
232297 ast_cache : dict [str , ast .AST | None ] = {}
298+ stmt_index_cache : dict [str , _StatementIndex | None ] = {}
233299 range_cache : dict [tuple [str , int , int ], tuple [int , int , int ]] = {}
234300 facts_by_group : dict [str , dict [str , str ]] = {}
235301
@@ -239,6 +305,7 @@ def build_block_group_facts(block_groups: GroupMapLike) -> dict[str, dict[str, s
239305 facts = facts ,
240306 items = items ,
241307 ast_cache = ast_cache ,
308+ stmt_index_cache = stmt_index_cache ,
242309 range_cache = range_cache ,
243310 )
244311 group_arity = len (items )
0 commit comments