@@ -35,6 +35,9 @@ def __init__(self, args=None, python_version: Optional[str] = None) -> None:
3535 if python_version is None and args is not None :
3636 python_version = getattr (args , "python_version" , None )
3737 self ._python_version = python_version
38+ self ._major , self ._minor = self ._parse_python_version (self ._python_version )
39+ self ._ast_mod = ast27 if self ._major == 2 else ast
40+ self ._feature_version = self ._minor if self ._major == 3 else None
3841 self ._module_names_set : Set [str ] = set ()
3942
4043 def get_source_files (self , paths ) -> List [str ]:
@@ -116,16 +119,14 @@ def _read_file_content(self, path: Text) -> Text:
116119 return file_read .read ()
117120
118121 def _parse_source (self , source : Text , filename : Text ):
119- major , minor = self ._parse_python_version (self ._python_version )
120- if major == 2 :
122+ if self ._major == 2 :
121123 if ast27 is None :
122124 raise ImportError (
123125 "typed_ast is required to parse Python 2 source code."
124126 )
125127 return ast27 .parse (source , filename = filename , mode = "exec" )
126128
127- feature_version = minor if minor is not None else None
128- return self ._parse_with_feature_version (source , filename , feature_version )
129+ return self ._parse_with_feature_version (source , filename , self ._feature_version )
129130
130131 def _parse_with_feature_version (self , source : Text , filename : Text , feature_version : Optional [int ]):
131132 if feature_version is None :
@@ -159,26 +160,25 @@ def _parse_python_version(self, version: Optional[str]) -> Tuple[int, Optional[i
159160 def _extract_entities (self , ast_tree , filename : Text ) -> (List [object ], Dict [str , object ]):
160161 entities : List [object ] = []
161162 entity_nodes : Dict [str , object ] = {}
162- ast_mod = ast27 if self ._python_version and self ._python_version .startswith ("2" ) else ast
163-
164- async_def = getattr (ast_mod , "AsyncFunctionDef" , None )
163+ ast_mod = self ._ast_mod
164+ async_def = getattr (self ._ast_mod , "AsyncFunctionDef" , None )
165165
166166 for node in getattr (ast_tree , "body" , []):
167- if isinstance (node , ast_mod .FunctionDef ):
167+ if isinstance (node , self . _ast_mod .FunctionDef ):
168168 func = Function (node .name , filename , node .lineno )
169- func .endno = self ._get_end_lineno (node , ast_mod )
169+ func .endno = self ._get_end_lineno (node , self . _ast_mod )
170170 entities .append (func )
171171 entity_nodes [node .name ] = node
172172 elif async_def and isinstance (node , async_def ):
173173 func = AsyncFunction (node .name , filename , node .lineno )
174- func .endno = self ._get_end_lineno (node , ast_mod )
174+ func .endno = self ._get_end_lineno (node , self . _ast_mod )
175175 entities .append (func )
176176 entity_nodes [node .name ] = node
177- elif isinstance (node , ast_mod .ClassDef ):
177+ elif isinstance (node , self . _ast_mod .ClassDef ):
178178 bases = [self ._get_name_from_expr (base , ast_mod ) for base in node .bases ]
179179 bases = [b for b in bases if b ]
180180 cls = Class (node .name , bases , filename , node .lineno )
181- cls .endno = self ._get_end_lineno (node , ast_mod )
181+ cls .endno = self ._get_end_lineno (node , self . _ast_mod )
182182 entities .append (cls )
183183 entity_nodes [node .name ] = node
184184
@@ -188,7 +188,7 @@ def _collect_imports(self, ast_tree) -> ImportInfo:
188188 module_aliases : Dict [str , str ] = {}
189189 entity_aliases : Dict [str , str ] = {}
190190 module_imports : Set [str ] = set ()
191- ast_mod = ast27 if self ._python_version and self . _python_version . startswith ( "2" ) else ast
191+ ast_mod = self ._ast_mod
192192
193193 for node in getattr (ast_tree , "body" , []):
194194 if isinstance (node , ast_mod .Import ):
@@ -207,10 +207,8 @@ def _collect_imports(self, ast_tree) -> ImportInfo:
207207 full_name = f"{ base } .{ name } " if base else name
208208 resolved = self ._resolve_imported_module (full_name )
209209 if resolved :
210- module_aliases [alias_name ] = resolved
211210 module_imports .add (resolved )
212- else :
213- entity_aliases [alias_name ] = full_name
211+ entity_aliases [alias_name ] = full_name
214212
215213 return ImportInfo (
216214 module_aliases = module_aliases ,
@@ -234,14 +232,14 @@ def _collect_dependencies_in_module(
234232 entity_aliases : Dict [str , str ],
235233 ) -> List [str ]:
236234 deps : List [str ] = []
237- ast_mod = ast27 if self ._python_version and self . _python_version . startswith ( "2" ) else ast
235+ ast_mod = self ._ast_mod
238236 collector = self ._make_dependency_collector (
239237 local_entities , module_aliases , entity_aliases , ast_mod
240238 )
241239 for node in getattr (ast_tree , "body" , []):
242240 if isinstance (node , (ast_mod .FunctionDef , ast_mod .ClassDef )):
243241 continue
244- async_def = getattr (ast_mod , "AsyncFunctionDef" , None )
242+ async_def = getattr (self . _ast_mod , "AsyncFunctionDef" , None )
245243 if async_def and isinstance (node , async_def ):
246244 continue
247245 collector .visit (node )
@@ -256,7 +254,7 @@ def _collect_dependencies_in_entity(
256254 entity_aliases : Dict [str , str ],
257255 ) -> List [str ]:
258256 deps : List [str ] = []
259- ast_mod = ast27 if self ._python_version and self . _python_version . startswith ( "2" ) else ast
257+ ast_mod = self ._ast_mod
260258
261259 if isinstance (node , ast_mod .ClassDef ):
262260 for base in node .bases :
@@ -339,7 +337,9 @@ def _resolve_attribute(
339337 suffix = "." .join (parts [1 :])
340338 return f"{ module_name } .{ suffix } " if suffix else module_name
341339 if base in entity_aliases :
342- return entity_aliases [base ]
340+ base_name = entity_aliases [base ]
341+ suffix = "." .join (parts [1 :])
342+ return f"{ base_name } .{ suffix } " if suffix else base_name
343343 if base in local_entities :
344344 return base
345345 return None
@@ -369,6 +369,14 @@ def _flatten_attribute(self, node, ast_mod) -> List[str]:
369369 return parts
370370 return []
371371
372+ def _get_name_from_expr (self , node , ast_mod ) -> Optional [str ]:
373+ if isinstance (node , ast_mod .Name ):
374+ return node .id
375+ if isinstance (node , ast_mod .Attribute ):
376+ parts = self ._flatten_attribute (node , ast_mod )
377+ return "." .join (parts ) if parts else None
378+ return None
379+
372380 def _get_end_lineno (self , node , ast_mod ) -> int :
373381 end_lineno = getattr (node , "end_lineno" , None )
374382 if end_lineno :
0 commit comments