1+ from __future__ import annotations
12from copy import copy
23from inspect import isclass , signature , Signature , getmodule
34from typing import (
@@ -396,6 +397,8 @@ class EvaluationContext:
396397 policy_overrides : dict = field (default_factory = dict )
397398 #: Transient local namespace used to store mocks
398399 transient_locals : dict = field (default_factory = dict )
400+ #: Tranisents of class level
401+ class_transients : dict | None = None
399402
400403 def replace (self , / , ** changes ):
401404 """Return a new copy of the context, with specified changes"""
@@ -560,6 +563,7 @@ def _validate_policy_overrides(
560563def _handle_assign (node : ast .Assign , context : EvaluationContext ):
561564 value = eval_node (node .value , context )
562565 transient_locals = context .transient_locals
566+ class_transients = getattr (context , "class_transients" , None )
563567 for target in node .targets :
564568 if isinstance (target , (ast .Tuple , ast .List )):
565569 # Handle unpacking assignment
@@ -573,21 +577,55 @@ def _handle_assign(node: ast.Assign, context: EvaluationContext):
573577 # Before starred
574578 for i in range (star_or_last_idx ):
575579 transient_locals [targets [i ].id ] = values [i ]
580+ # Check for self.x assignment
581+ if class_transients is not None and hasattr (targets [i ], "ctx" ):
582+ if (
583+ isinstance (targets [i ], ast .Attribute )
584+ and isinstance (targets [i ].value , ast .Name )
585+ and targets [i ].value .id == "self"
586+ ):
587+ class_transients [targets [i ].attr ] = values [i ]
576588
577589 # Starred if exists
578590 if starred :
579591 end = len (values ) - (len (targets ) - star_or_last_idx - 1 )
580592 transient_locals [targets [star_or_last_idx ].value .id ] = values [
581593 star_or_last_idx :end
582594 ]
595+ if (
596+ class_transients is not None
597+ and isinstance (targets [star_or_last_idx ], ast .Attribute )
598+ and isinstance (targets [star_or_last_idx ].value , ast .Name )
599+ and targets [star_or_last_idx ].value .id == "self"
600+ ):
601+ class_transients [targets [star_or_last_idx ].attr ] = values [
602+ star_or_last_idx :end
603+ ]
583604
584605 # After starred
585606 for i in range (star_or_last_idx + 1 , len (targets )):
586607 transient_locals [targets [i ].id ] = values [
587608 len (values ) - (len (targets ) - i )
588609 ]
610+ if (
611+ class_transients is not None
612+ and isinstance (targets [i ], ast .Attribute )
613+ and isinstance (targets [i ].value , ast .Name )
614+ and targets [i ].value .id == "self"
615+ ):
616+ class_transients [targets [i ].attr ] = values [
617+ len (values ) - (len (targets ) - i )
618+ ]
589619 else :
590- transient_locals [target .id ] = value
620+ if (
621+ isinstance (target , ast .Attribute )
622+ and isinstance (target .value , ast .Name )
623+ and target .value .id == "self"
624+ ):
625+ if class_transients is not None :
626+ class_transients [target .attr ] = value
627+ elif hasattr (target , "id" ):
628+ transient_locals [target .id ] = value
591629 return None
592630
593631
@@ -643,6 +681,10 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
643681 return result
644682 if isinstance (node , ast .FunctionDef ):
645683 # we ignore body and only extract the return type
684+ func_locals = context .transient_locals .copy ()
685+ func_context = context .replace (transient_locals = func_locals )
686+ for child_node in node .body :
687+ eval_node (child_node , func_context )
646688 is_property = False
647689
648690 for decorator_node in node .decorator_list :
@@ -662,7 +704,7 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
662704 return_type , context
663705 )
664706 else :
665- return_value = _infer_return_value (node , context )
707+ return_value = _infer_return_value (node , func_context )
666708 context .transient_locals [node .name ] = return_value
667709
668710 return None
@@ -673,7 +715,7 @@ def dummy_function(*args, **kwargs):
673715 if return_type is not None :
674716 dummy_function .__annotations__ ["return" ] = return_type
675717 else :
676- inferred_return = _infer_return_value (node , context )
718+ inferred_return = _infer_return_value (node , func_context )
677719 if inferred_return is not None :
678720 dummy_function .__inferred_return__ = inferred_return
679721
@@ -685,6 +727,7 @@ def dummy_function(*args, **kwargs):
685727 # TODO support class decorators?
686728 class_locals = context .transient_locals .copy ()
687729 class_context = context .replace (transient_locals = class_locals )
730+ class_context .class_transients = class_locals
688731 for child_node in node .body :
689732 eval_node (child_node , class_context )
690733 bases = tuple ([eval_node (base , context ) for base in node .bases ])
@@ -694,12 +737,19 @@ def dummy_function(*args, **kwargs):
694737 if isinstance (node , ast .Assign ):
695738 return _handle_assign (node , context )
696739 if isinstance (node , ast .AnnAssign ):
697- if not node .simple :
698- # for now only handle simple annotations
699- return None
700- context .transient_locals [node .target .id ] = _resolve_annotation (
701- eval_node (node .annotation , context ), context
702- )
740+ if node .simple :
741+ value = _resolve_annotation (eval_node (node .annotation , context ), context )
742+ context .transient_locals [node .target .id ] = value
743+ # Handle non-simple annotated assignments only for self.x: type = value
744+ class_transients = getattr (context , "class_transients" , None )
745+ if (
746+ class_transients is not None
747+ and isinstance (node .target , ast .Attribute )
748+ and isinstance (node .target .value , ast .Name )
749+ and node .target .value .id == "self"
750+ ):
751+ value = _resolve_annotation (eval_node (node .annotation , context ), context )
752+ class_transients [node .target .attr ] = value
703753 return None
704754 if isinstance (node , ast .Expression ):
705755 return eval_node (node .body , context )
@@ -897,22 +947,9 @@ def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
897947
898948
899949def _collect_return_values (body , context ):
900- """Recursively collect return values from a list of AST statements.
901-
902- For every assignment or annotated assignment, store them in context.transient_locals
903- so that return statements can refer to them.
904- """
950+ """Recursively collect return values from a list of AST statements."""
905951 return_values = []
906952 for stmt in body :
907- # Handle assignments
908- if isinstance (stmt , ast .Assign ):
909- _handle_assign (stmt , context )
910- elif isinstance (stmt , ast .AnnAssign ):
911- if stmt .simple :
912- context .transient_locals [stmt .target .id ] = _resolve_annotation (
913- eval_node (stmt .annotation , context ), context
914- )
915- # Handle return statements
916953 if isinstance (stmt , ast .Return ):
917954 if stmt .value is None :
918955 continue
0 commit comments