11import collections
22import inspect
33import textwrap
4- from typing import Any , List , Optional , Tuple , Union
4+ from typing import Any , Optional , OrderedDict
55
66import matplotlib .pyplot as plt
77import networkx as nx
88import numpy as np
9+ import numpy .typing as npt
910import pydot
1011from matplotlib .axes import Axes
1112from matplotlib .colors import rgb2hex
@@ -24,7 +25,7 @@ def __init__(
2425 color : Optional [str ] = "#000000" ,
2526 ):
2627 self .child = child
27- self .child_name = child .__name__
28+ self .child_name : str = str ( child .__name__ )
2829 self ._child_id = child_id
2930 self .parent = parent
3031
@@ -41,10 +42,13 @@ def child_id(self) -> str:
4142 return str (self ._child_id )
4243
4344 @property
44- def parent_id (self ) -> str :
45+ def parent_id (self ) -> str | None :
4546 if self ._parent_id :
4647 return str (self ._parent_id )
47- return
48+ return None
49+
50+
51+ _similarity_container_types = PycodeSimilarity
4852
4953
5054class ClassGraphTree :
@@ -76,38 +80,44 @@ def __init__(
7680 self ,
7781 baseclass : Any ,
7882 funcname : Optional [str ] = None ,
79- default_color : Optional [ str ] = "#000000" ,
80- func_override_color : Optional [ str ] = "#ff0000" ,
81- similarity_cutoff : Optional [ float ] = 0.75 ,
82- max_recursion_level : Optional [ int ] = 500 ,
83- classes_to_exclude : Optional [List [str ]] = None ,
83+ default_color : str = "#000000" ,
84+ func_override_color : str = "#ff0000" ,
85+ similarity_cutoff : float = 0.75 ,
86+ max_recursion_level : int = 500 ,
87+ classes_to_exclude : Optional [list [str ]] = None ,
8488 ):
8589
8690 self .baseclass = baseclass
8791 self .basename : str = baseclass .__name__
8892 self .funcname = funcname
93+ self ._tracking_function = self .funcname is not None
8994 self .max_recursion_level = max_recursion_level
9095 self ._nodenum : int = 0
91- self ._node_list = [] # a list of unique ChildNodes
92- self ._node_map = {} # map of global node index to node name
93- self ._override_src = collections .OrderedDict ()
94- self ._override_src_files = {}
96+ self ._node_list : list [ _ChildNode ] = [] # a list of unique ChildNodes
97+ self ._node_map : dict [ int , str ] = {} # map of global node index to node name
98+ self ._override_src : OrderedDict [ int , str ] = collections .OrderedDict ()
99+ self ._override_src_files : dict [ int , str ] = {}
95100 self ._current_node = 1 # the current global node, must start at 1
96101 self ._default_color = default_color
97102 self ._override_color = func_override_color
98- self ._graphviz_args_kwargs = {}
99- self .similarity_container = None
100- self .similarity_results = None
103+ self ._graphviz_args_kwargs : dict [ str , Any ] = {}
104+ self .similarity_container : _similarity_container_types | None = None
105+ self .similarity_results : dict [ str , npt . NDArray [ Any ]]
101106 self .similarity_cutoff = similarity_cutoff
102107 if classes_to_exclude is None :
103108 classes_to_exclude = []
104109 self .classes_to_exclude = classes_to_exclude
105110 self ._build ()
106- self ._node_map_r = {v : k for k , v in self ._node_map .items ()} # name to index
111+ self ._node_map_r : dict [str , int ] = {
112+ v : k for k , v in self ._node_map .items ()
113+ } # name to index
107114
108115 def _get_source_info (self , obj ) -> Optional [str ]:
109- f = getattr (obj , self .funcname )
110- if isinstance (f , collections .abc .Callable ):
116+ if self .funcname is None :
117+ raise RuntimeError ("this functionality requires function tracking." )
118+ fname : str = self .funcname
119+ f = getattr (obj , fname )
120+ if isinstance (f , collections .abc .Callable ): # type: ignore[arg-type]
111121 return f"{ inspect .getsourcefile (f )} :{ inspect .getsourcelines (f )[1 ]} "
112122 return None
113123
@@ -157,8 +167,12 @@ def _store_node_func_source(self, clss, current_node: int):
157167 # store the source code of funcname for the current class and node
158168 # clss: a class
159169 # current_node: the
160- f = getattr (clss , self .funcname )
161- if isinstance (f , collections .abc .Callable ):
170+ if self .funcname is None :
171+ raise RuntimeError ("this functionality requires function tracking." )
172+ fname : str = self .funcname
173+
174+ f = getattr (clss , fname )
175+ if isinstance (f , collections .abc .Callable ): # type: ignore[arg-type]
162176 src = textwrap .dedent (inspect .getsource (f ))
163177 self ._override_src_files [current_node ] = (
164178 f"{ inspect .getsourcefile (f )} :{ inspect .getsourcelines (f )[1 ]} "
@@ -167,21 +181,28 @@ def _store_node_func_source(self, clss, current_node: int):
167181
168182 def check_source_similarity (
169183 self ,
170- SimilarityContainer = PycodeSimilarity ,
171- method = "reference" ,
184+ similarity_container_class : str = " PycodeSimilarity" ,
185+ method : str = "reference" ,
172186 reference : Optional [int ] = None ,
173187 ):
174188 # compares all the source code of the child methods that have
175189 # over-ridden funcname
176190
177191 if reference is None :
178- reference = 1 # use whatever the basenode is
192+ ref = 1 # use whatever the basenode is
193+ else :
194+ ref = reference
195+
196+ if similarity_container_class == "PycodeSimilarity" :
197+ SimClass = PycodeSimilarity
198+ else :
199+ raise ValueError (f"unexpected value, { similarity_container_class = } " )
179200
180- self .similarity_container = SimilarityContainer (method = method )
181- sim = self .similarity_container .run (self ._override_src , reference = reference )
201+ self .similarity_container = SimClass (method = method )
202+ sim = self .similarity_container .run (self ._override_src , reference = ref )
182203 return sim
183204
184- def _build (self ):
205+ def _build (self ) -> None :
185206
186207 # construct the first node
187208 color = self ._get_baseclass_color ()
@@ -201,11 +222,12 @@ def _build(self):
201222 # construct the full similarity matrix
202223 s_c = PycodeSimilarity (method = "permute" )
203224 _ , sim_matrix , sim_axis = s_c .run (self ._override_src )
204- sim_axis = np .array (sim_axis )
225+ assert isinstance (sim_matrix , np .ndarray )
226+ sim_axis_array = np .array (sim_axis )
205227 sim_axis_names = np .array ([c .child_name for c in self ._node_list ])
206228 self .similarity_results = {
207229 "matrix" : sim_matrix ,
208- "axis" : sim_axis ,
230+ "axis" : sim_axis_array ,
209231 "axis_names" : sim_axis_names ,
210232 }
211233
@@ -216,9 +238,9 @@ def _build(self):
216238 rowvals = M [irow , :]
217239 indxs = np .where (rowvals >= cutoff_sim )[0 ]
218240 indxs = indxs [indxs != irow ] # these are matrix indeces
219- node_ids = sim_axis [indxs ]
241+ node_ids = sim_axis_array [indxs ]
220242 if len (node_ids ) > 0 :
221- this_child = sim_axis [irow ]
243+ this_child = sim_axis_array [irow ]
222244 similarity_sets [this_child ] = set (node_ids .tolist ())
223245 self .similarity_sets = similarity_sets
224246
@@ -292,7 +314,7 @@ def plot_similarity(
292314 ax : Optional [Axes ] = None ,
293315 colorbar : Optional [bool ] = True ,
294316 ** kwargs ,
295- ) -> Tuple [dict , Axes ]:
317+ ) -> tuple [dict [ int , str ] , Axes ]:
296318 """
297319 add the similarity plot to a matplotlib axis (or create a new one)
298320
@@ -340,16 +362,16 @@ def plot_similarity(
340362 sim_labels = [
341363 self ._node_list [cid - 1 ].child_name for cid in self ._override_src .keys ()
342364 ]
343- sim_labels = {lid : label for lid , label in enumerate (sim_labels )}
344- return sim_labels , ax
365+ sim_labels_dict = {lid : label for lid , label in enumerate (sim_labels )}
366+ return sim_labels_dict , ax
345367
346368 def build_interactive_graph (
347369 self ,
348370 include_similarity : bool = True ,
349- node_style : dict = None ,
350- edge_style : dict = None ,
351- similarity_edge_style : dict = None ,
352- override_node_color : Union [ str , tuple ] = None ,
371+ node_style : dict [ str , Any ] | None = None ,
372+ edge_style : dict [ str , Any ] | None = None ,
373+ similarity_edge_style : dict [ str , Any ] | None = None ,
374+ override_node_color : str | tuple [ float , ...] | None = None ,
353375 ** kwargs ,
354376 ) -> Network :
355377 """
@@ -454,34 +476,41 @@ def build_interactive_graph(
454476 network_wrapper .from_nx (grph )
455477 return network_wrapper
456478
457- def get_source_code (self , node : Union [ str , int ] ) -> str :
479+ def get_source_code (self , node : int | str ) -> str :
458480 """
459481 retrieve the source code of the comparison function for a
460482 specified node
461483
462484 Parameters
463485 ----------
464- node: Union[str, int]
486+ node: int
465487 the node to fetch the source code for
466488
467489 Returns
468490 -------
469491 str
470492 a string containing the source code for the node.
471493 """
472- if node in self ._override_src :
473- return self ._override_src [node ]
494+ node_id : int
495+
496+ if not isinstance (node , int ) and not isinstance (node , str ):
497+ raise TypeError ("Unexpected type for node" )
498+
499+ if isinstance (node , int ) and node in self ._node_map :
500+ node_id = node
474501 elif isinstance (node , str ) and node in self ._node_map_r :
475502 node_id = self ._node_map_r [node ]
476- if node_id in self . _override_src :
477- return self . _override_src [ node_id ]
478- else :
479- raise ValueError (
480- f"node { node } does not override the " f"chosen function."
481- )
482- raise KeyError (f"Could not find node for { node } " )
503+ else :
504+ raise ValueError ( f"Could not find node for { node } " )
505+
506+ if node_id in self . _override_src :
507+ return self . _override_src [ node_id ]
508+ else :
509+ raise ValueError (f"node { node } does not override the chosen function. " )
483510
484- def get_multiple_source_code (self , node_1 : Union [str , int ], * args ) -> dict :
511+ def get_multiple_source_code (
512+ self , node_1 : int | str , * args
513+ ) -> dict [int | str , str ]:
485514 """
486515 Retrieve the source code for multiple nodes
487516
@@ -515,11 +544,11 @@ def display_code_comparison(self):
515544 display_code_compare (self )
516545
517546
518- def _validate_color (clr , default_rgb_tuple : tuple ) -> str :
547+ def _validate_color (clr , default_rgb_tuple : tuple [ float , float , float ] ) -> str :
519548 if clr is None :
520- return rgb2hex (default_rgb_tuple )
549+ return str ( rgb2hex (default_rgb_tuple ) )
521550 elif isinstance (clr , tuple ):
522- return rgb2hex (clr )
551+ return str ( rgb2hex (clr ) )
523552 elif isinstance (clr , str ):
524553 return clr
525554 msg = f"clr has unexpected type: { type (clr )} "
0 commit comments