55import re
66from typing import Optional , Callable
77
8- from parser import Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt
8+ from parser import Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt , ForStmt , MacroIfStmt
99
1010@dataclass
1111class EscapingCall :
@@ -723,53 +723,57 @@ def visit(stmt: Stmt) -> None:
723723 if error is not None :
724724 raise analysis_error (f"Escaping call '{ error .text } in condition" , error )
725725
726+ def escaping_call_in_simple_stmt (stmt : SimpleStmt , result : dict [Stmt , EscapingCall ]):
727+ tokens = stmt .contents
728+ for idx , tkn in enumerate (tokens ):
729+ try :
730+ next_tkn = tokens [idx + 1 ]
731+ except IndexError :
732+ break
733+ if next_tkn .kind != lexer .LPAREN :
734+ continue
735+ if tkn .kind == lexer .IDENTIFIER :
736+ if tkn .text .upper () == tkn .text :
737+ # simple macro
738+ continue
739+ #if not tkn.text.startswith(("Py", "_Py", "monitor")):
740+ # continue
741+ if tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
742+ # Optimize functions
743+ continue
744+ if tkn .text .endswith ("Check" ):
745+ continue
746+ if tkn .text .startswith ("Py_Is" ):
747+ continue
748+ if tkn .text .endswith ("CheckExact" ):
749+ continue
750+ if tkn .text in NON_ESCAPING_FUNCTIONS :
751+ continue
752+ elif tkn .kind == "RPAREN" :
753+ prev = tokens [idx - 1 ]
754+ if prev .text .endswith ("_t" ) or prev .text == "*" or prev .text == "int" :
755+ #cast
756+ continue
757+ elif tkn .kind != "RBRACKET" :
758+ continue
759+ if tkn .text in ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
760+ if len (tokens ) <= idx + 2 :
761+ raise analysis_error ("Unexpected end of file" , next_tkn )
762+ kills = tokens [idx + 2 ]
763+ if kills .kind != "IDENTIFIER" :
764+ raise analysis_error (f"Expected identifier, got '{ kills .text } '" , kills )
765+ else :
766+ kills = None
767+ result [stmt ] = EscapingCall (stmt , tkn , kills )
768+
769+
726770def find_escaping_api_calls (instr : parser .CodeDef ) -> dict [SimpleStmt , EscapingCall ]:
727771 result : dict [SimpleStmt , EscapingCall ] = {}
728772
729773 def visit (stmt : Stmt ) -> None :
730774 if not isinstance (stmt , SimpleStmt ):
731775 return
732- tokens = stmt .contents
733- for idx , tkn in enumerate (tokens ):
734- try :
735- next_tkn = tokens [idx + 1 ]
736- except IndexError :
737- break
738- if next_tkn .kind != lexer .LPAREN :
739- continue
740- if tkn .kind == lexer .IDENTIFIER :
741- if tkn .text .upper () == tkn .text :
742- # simple macro
743- continue
744- #if not tkn.text.startswith(("Py", "_Py", "monitor")):
745- # continue
746- if tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
747- # Optimize functions
748- continue
749- if tkn .text .endswith ("Check" ):
750- continue
751- if tkn .text .startswith ("Py_Is" ):
752- continue
753- if tkn .text .endswith ("CheckExact" ):
754- continue
755- if tkn .text in NON_ESCAPING_FUNCTIONS :
756- continue
757- elif tkn .kind == "RPAREN" :
758- prev = tokens [idx - 1 ]
759- if prev .text .endswith ("_t" ) or prev .text == "*" or prev .text == "int" :
760- #cast
761- continue
762- elif tkn .kind != "RBRACKET" :
763- continue
764- if tkn .text in ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
765- if len (tokens ) <= idx + 2 :
766- raise analysis_error ("Unexpected end of file" , next_tkn )
767- kills = tokens [idx + 2 ]
768- if kills .kind != "IDENTIFIER" :
769- raise analysis_error (f"Expected identifier, got '{ kills .text } '" , kills )
770- else :
771- kills = None
772- result [stmt ] = EscapingCall (stmt , tkn , kills )
776+ escaping_call_in_simple_stmt (stmt , result )
773777
774778 instr .block .accept (visit )
775779 check_escaping_calls (instr , result )
@@ -822,6 +826,56 @@ def stack_effect_only_peeks(instr: parser.InstDef) -> bool:
822826 )
823827
824828
829+ def op_escapes (op : parser .CodeDef ):
830+
831+ def is_simple_exit (stmt : Stmt ):
832+ if not isinstance (stmt , SimpleStmt ):
833+ return False
834+ tokens = stmt .contents
835+ if len (tokens ) < 4 :
836+ return False
837+ return (
838+ tokens [0 ].text in ("ERROR_IF" , "DEOPT_IF" , "EXIT_IF" )
839+ and
840+ tokens [1 ].text == "("
841+ and
842+ tokens [2 ].text in ("true" , "1" )
843+ and
844+ tokens [3 ].text == ")"
845+ )
846+
847+ def escapes_list (stmts : list [Stmt ]):
848+ if not stmts :
849+ return False
850+ if is_simple_exit (stmts [- 1 ]):
851+ return False
852+ for stmt in stmts :
853+ if escapes (stmt ):
854+ return True
855+ return False
856+
857+ def escapes (stmt : Stmt ) -> None :
858+ if isinstance (stmt , BlockStmt ):
859+ return escapes_list (stmt .body )
860+ elif isinstance (stmt , SimpleStmt ):
861+ d : dict [Stmt , EscapingCall ] = {}
862+ escaping_call_in_simple_stmt (stmt , d )
863+ return bool (d )
864+ elif isinstance (stmt , IfStmt ):
865+ if stmt .else_body and escapes (stmt .else_body ):
866+ return True
867+ return escapes (stmt .body )
868+ elif isinstance (stmt , MacroIfStmt ):
869+ if stmt .else_body and escapes_list (stmt .else_body ):
870+ return True
871+ return escapes_list (stmt .body )
872+ elif isinstance (stmt , ForStmt ):
873+ return escapes (stmt .body )
874+ elif isinstance (stmt , WhileStmt ):
875+ return escapes (stmt .body )
876+
877+ return escapes (op .block )
878+
825879def compute_properties (op : parser .CodeDef ) -> Properties :
826880 escaping_calls = find_escaping_api_calls (op )
827881 has_free = (
@@ -843,7 +897,7 @@ def compute_properties(op: parser.CodeDef) -> Properties:
843897 )
844898 error_with_pop = has_error_with_pop (op )
845899 error_without_pop = has_error_without_pop (op )
846- escapes = bool ( escaping_calls ) or variable_used (op , "DECREF_INPUTS" )
900+ escapes = op_escapes ( op ) or variable_used (op , "DECREF_INPUTS" )
847901 pure = False if isinstance (op , parser .LabelDef ) else "pure" in op .annotations
848902 no_save_ip = False if isinstance (op , parser .LabelDef ) else "no_save_ip" in op .annotations
849903 return Properties (
0 commit comments