Skip to content

Commit ecde164

Browse files
authored
Merge pull request #42 from chrishavlin/enable_typechecking_ci
add mypy check
2 parents 79263ba + 918973b commit ecde164

16 files changed

Lines changed: 311 additions & 138 deletions

.coveragerc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[run]
2+
source = inheritance_explorer
3+
4+
[report]
5+
fail_under = 97
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
name: Run tests
22

33
on:
4-
pull_request:
5-
paths:
6-
- '**.py'
7-
- '**run-tests.yml'
4+
pull_request:
85
schedule:
96
- cron: "30 1 * * 1"
107

.github/workflows/type-check.yaml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
name: type checking
2+
3+
on:
4+
pull_request:
5+
paths:
6+
- inheritance_explorer/**/*.py
7+
- pyproject.toml
8+
- requirements/typecheck.txt
9+
- .github/workflows/type-checking.yaml
10+
workflow_dispatch:
11+
12+
jobs:
13+
build:
14+
runs-on: ubuntu-latest
15+
name: type check
16+
timeout-minutes: 60
17+
18+
steps:
19+
- name: Checkout repo
20+
uses: actions/checkout@v4
21+
22+
- name: Set up Python
23+
uses: actions/setup-python@v5
24+
with:
25+
# run with oldest supported python version
26+
# so that we always get compatible versions of
27+
# core dependencies at type-check time
28+
python-version: '3.10'
29+
30+
- name: Build
31+
run: |
32+
python3 -m pip install --upgrade pip
33+
python3 -m pip install -r requirements/typecheck.txt
34+
35+
- name: list installed deps
36+
run: python -m pip list
37+
38+
- name: Run mypy
39+
run: mypy inheritance_explorer

HISTORY.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
History
33
=======
44

5+
v0.2.0dev
6+
---------
7+
* Improved typing (`PR 42 <https://github.com/data-exp-lab/inheritance_explorer/pull/42>`_)
8+
59
v0.2.0
610
------
711
* enable styling on interactive graphs (see `new example <https://inheritance-explorer.readthedocs.io/en/latest/examples/ex_006_interactive_graph_styles.html>`_)

docs/examples/ex_001_basic_usage.ipynb

Lines changed: 6 additions & 14 deletions
Large diffs are not rendered by default.

docs/examples/ex_002_inheritance_scope.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@
265265
"name": "python",
266266
"nbconvert_exporter": "python",
267267
"pygments_lexer": "ipython3",
268-
"version": "3.9.0"
268+
"version": "3.10.11"
269269
}
270270
},
271271
"nbformat": 4,

docs/examples/ex_003_Qwidgets.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@
214214
"name": "python",
215215
"nbconvert_exporter": "python",
216216
"pygments_lexer": "ipython3",
217-
"version": "3.9.0"
217+
"version": "3.10.11"
218218
}
219219
},
220220
"nbformat": 4,

inheritance_explorer/_testing.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
class ClassForTesting:
2-
def use_this_func(self, a):
2+
misc_attr: int = 1
3+
4+
def use_this_func(self, a: int) -> int:
35
return a
46

57

68
class ClassForTesting2(ClassForTesting):
7-
def use_this_func(self, a):
9+
def use_this_func(self, a: int) -> int:
810
b = a * 10
911
return b
1012

@@ -14,7 +16,7 @@ class ClassForTesting3(ClassForTesting):
1416

1517

1618
class ClassForTesting4(ClassForTesting2):
17-
def use_this_func(self, a):
19+
def use_this_func(self, a: int) -> int:
1820
b = a * 10
1921
c = b + 10
2022
return c

inheritance_explorer/_widget_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def update_source_2(event):
7373

7474
class_dropdown_1.observe(update_source_1, ["value"])
7575
class_dropdown_2.observe(update_source_2, ["value"])
76-
update_source_1(None)
77-
update_source_2(None)
76+
update_source_1(None) # type: ignore[no-untyped-call]
77+
update_source_2(None) # type: ignore[no-untyped-call]
7878

7979
if class_1_name is not None:
8080
class_dropdown_1.value = class_1_name

inheritance_explorer/inheritance_explorer.py

Lines changed: 82 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import collections
22
import inspect
33
import textwrap
4-
from typing import Any, List, Optional, Tuple, Union
4+
from typing import Any, Optional, OrderedDict
55

66
import matplotlib.pyplot as plt
77
import networkx as nx
88
import numpy as np
9+
import numpy.typing as npt
910
import pydot
1011
from matplotlib.axes import Axes
1112
from matplotlib.colors import rgb2hex
@@ -24,7 +25,7 @@ def __init__(
2425
color: Optional[str] = "#000000",
2526
):
2627
self.child = child
27-
self.child_name = child.__name__
28+
self.child_name: str = str(child.__name__)
2829
self._child_id = child_id
2930
self.parent = parent
3031

@@ -41,10 +42,13 @@ def child_id(self) -> str:
4142
return str(self._child_id)
4243

4344
@property
44-
def parent_id(self) -> str:
45+
def parent_id(self) -> str | None:
4546
if self._parent_id:
4647
return str(self._parent_id)
47-
return
48+
return None
49+
50+
51+
_similarity_container_types = PycodeSimilarity
4852

4953

5054
class ClassGraphTree:
@@ -76,38 +80,44 @@ def __init__(
7680
self,
7781
baseclass: Any,
7882
funcname: Optional[str] = None,
79-
default_color: Optional[str] = "#000000",
80-
func_override_color: Optional[str] = "#ff0000",
81-
similarity_cutoff: Optional[float] = 0.75,
82-
max_recursion_level: Optional[int] = 500,
83-
classes_to_exclude: Optional[List[str]] = None,
83+
default_color: str = "#000000",
84+
func_override_color: str = "#ff0000",
85+
similarity_cutoff: float = 0.75,
86+
max_recursion_level: int = 500,
87+
classes_to_exclude: Optional[list[str]] = None,
8488
):
8589

8690
self.baseclass = baseclass
8791
self.basename: str = baseclass.__name__
8892
self.funcname = funcname
93+
self._tracking_function = self.funcname is not None
8994
self.max_recursion_level = max_recursion_level
9095
self._nodenum: int = 0
91-
self._node_list = [] # a list of unique ChildNodes
92-
self._node_map = {} # map of global node index to node name
93-
self._override_src = collections.OrderedDict()
94-
self._override_src_files = {}
96+
self._node_list: list[_ChildNode] = [] # a list of unique ChildNodes
97+
self._node_map: dict[int, str] = {} # map of global node index to node name
98+
self._override_src: OrderedDict[int, str] = collections.OrderedDict()
99+
self._override_src_files: dict[int, str] = {}
95100
self._current_node = 1 # the current global node, must start at 1
96101
self._default_color = default_color
97102
self._override_color = func_override_color
98-
self._graphviz_args_kwargs = {}
99-
self.similarity_container = None
100-
self.similarity_results = None
103+
self._graphviz_args_kwargs: dict[str, Any] = {}
104+
self.similarity_container: _similarity_container_types | None = None
105+
self.similarity_results: dict[str, npt.NDArray[Any]]
101106
self.similarity_cutoff = similarity_cutoff
102107
if classes_to_exclude is None:
103108
classes_to_exclude = []
104109
self.classes_to_exclude = classes_to_exclude
105110
self._build()
106-
self._node_map_r = {v: k for k, v in self._node_map.items()} # name to index
111+
self._node_map_r: dict[str, int] = {
112+
v: k for k, v in self._node_map.items()
113+
} # name to index
107114

108115
def _get_source_info(self, obj) -> Optional[str]:
109-
f = getattr(obj, self.funcname)
110-
if isinstance(f, collections.abc.Callable):
116+
if self.funcname is None:
117+
raise RuntimeError("this functionality requires function tracking.")
118+
fname: str = self.funcname
119+
f = getattr(obj, fname)
120+
if isinstance(f, collections.abc.Callable): # type: ignore[arg-type]
111121
return f"{inspect.getsourcefile(f)}:{inspect.getsourcelines(f)[1]}"
112122
return None
113123

@@ -157,8 +167,12 @@ def _store_node_func_source(self, clss, current_node: int):
157167
# store the source code of funcname for the current class and node
158168
# clss: a class
159169
# current_node: the
160-
f = getattr(clss, self.funcname)
161-
if isinstance(f, collections.abc.Callable):
170+
if self.funcname is None:
171+
raise RuntimeError("this functionality requires function tracking.")
172+
fname: str = self.funcname
173+
174+
f = getattr(clss, fname)
175+
if isinstance(f, collections.abc.Callable): # type: ignore[arg-type]
162176
src = textwrap.dedent(inspect.getsource(f))
163177
self._override_src_files[current_node] = (
164178
f"{inspect.getsourcefile(f)}:{inspect.getsourcelines(f)[1]}"
@@ -167,21 +181,28 @@ def _store_node_func_source(self, clss, current_node: int):
167181

168182
def check_source_similarity(
169183
self,
170-
SimilarityContainer=PycodeSimilarity,
171-
method="reference",
184+
similarity_container_class: str = "PycodeSimilarity",
185+
method: str = "reference",
172186
reference: Optional[int] = None,
173187
):
174188
# compares all the source code of the child methods that have
175189
# over-ridden funcname
176190

177191
if reference is None:
178-
reference = 1 # use whatever the basenode is
192+
ref = 1 # use whatever the basenode is
193+
else:
194+
ref = reference
195+
196+
if similarity_container_class == "PycodeSimilarity":
197+
SimClass = PycodeSimilarity
198+
else:
199+
raise ValueError(f"unexpected value, {similarity_container_class=}")
179200

180-
self.similarity_container = SimilarityContainer(method=method)
181-
sim = self.similarity_container.run(self._override_src, reference=reference)
201+
self.similarity_container = SimClass(method=method)
202+
sim = self.similarity_container.run(self._override_src, reference=ref)
182203
return sim
183204

184-
def _build(self):
205+
def _build(self) -> None:
185206

186207
# construct the first node
187208
color = self._get_baseclass_color()
@@ -201,11 +222,12 @@ def _build(self):
201222
# construct the full similarity matrix
202223
s_c = PycodeSimilarity(method="permute")
203224
_, sim_matrix, sim_axis = s_c.run(self._override_src)
204-
sim_axis = np.array(sim_axis)
225+
assert isinstance(sim_matrix, np.ndarray)
226+
sim_axis_array = np.array(sim_axis)
205227
sim_axis_names = np.array([c.child_name for c in self._node_list])
206228
self.similarity_results = {
207229
"matrix": sim_matrix,
208-
"axis": sim_axis,
230+
"axis": sim_axis_array,
209231
"axis_names": sim_axis_names,
210232
}
211233

@@ -216,9 +238,9 @@ def _build(self):
216238
rowvals = M[irow, :]
217239
indxs = np.where(rowvals >= cutoff_sim)[0]
218240
indxs = indxs[indxs != irow] # these are matrix indeces
219-
node_ids = sim_axis[indxs]
241+
node_ids = sim_axis_array[indxs]
220242
if len(node_ids) > 0:
221-
this_child = sim_axis[irow]
243+
this_child = sim_axis_array[irow]
222244
similarity_sets[this_child] = set(node_ids.tolist())
223245
self.similarity_sets = similarity_sets
224246

@@ -292,7 +314,7 @@ def plot_similarity(
292314
ax: Optional[Axes] = None,
293315
colorbar: Optional[bool] = True,
294316
**kwargs,
295-
) -> Tuple[dict, Axes]:
317+
) -> tuple[dict[int, str], Axes]:
296318
"""
297319
add the similarity plot to a matplotlib axis (or create a new one)
298320
@@ -340,16 +362,16 @@ def plot_similarity(
340362
sim_labels = [
341363
self._node_list[cid - 1].child_name for cid in self._override_src.keys()
342364
]
343-
sim_labels = {lid: label for lid, label in enumerate(sim_labels)}
344-
return sim_labels, ax
365+
sim_labels_dict = {lid: label for lid, label in enumerate(sim_labels)}
366+
return sim_labels_dict, ax
345367

346368
def build_interactive_graph(
347369
self,
348370
include_similarity: bool = True,
349-
node_style: dict = None,
350-
edge_style: dict = None,
351-
similarity_edge_style: dict = None,
352-
override_node_color: Union[str, tuple] = None,
371+
node_style: dict[str, Any] | None = None,
372+
edge_style: dict[str, Any] | None = None,
373+
similarity_edge_style: dict[str, Any] | None = None,
374+
override_node_color: str | tuple[float, ...] | None = None,
353375
**kwargs,
354376
) -> Network:
355377
"""
@@ -454,34 +476,41 @@ def build_interactive_graph(
454476
network_wrapper.from_nx(grph)
455477
return network_wrapper
456478

457-
def get_source_code(self, node: Union[str, int]) -> str:
479+
def get_source_code(self, node: int | str) -> str:
458480
"""
459481
retrieve the source code of the comparison function for a
460482
specified node
461483
462484
Parameters
463485
----------
464-
node: Union[str, int]
486+
node: int
465487
the node to fetch the source code for
466488
467489
Returns
468490
-------
469491
str
470492
a string containing the source code for the node.
471493
"""
472-
if node in self._override_src:
473-
return self._override_src[node]
494+
node_id: int
495+
496+
if not isinstance(node, int) and not isinstance(node, str):
497+
raise TypeError("Unexpected type for node")
498+
499+
if isinstance(node, int) and node in self._node_map:
500+
node_id = node
474501
elif isinstance(node, str) and node in self._node_map_r:
475502
node_id = self._node_map_r[node]
476-
if node_id in self._override_src:
477-
return self._override_src[node_id]
478-
else:
479-
raise ValueError(
480-
f"node {node} does not override the " f"chosen function."
481-
)
482-
raise KeyError(f"Could not find node for {node}")
503+
else:
504+
raise ValueError(f"Could not find node for {node}")
505+
506+
if node_id in self._override_src:
507+
return self._override_src[node_id]
508+
else:
509+
raise ValueError(f"node {node} does not override the chosen function.")
483510

484-
def get_multiple_source_code(self, node_1: Union[str, int], *args) -> dict:
511+
def get_multiple_source_code(
512+
self, node_1: int | str, *args
513+
) -> dict[int | str, str]:
485514
"""
486515
Retrieve the source code for multiple nodes
487516
@@ -515,11 +544,11 @@ def display_code_comparison(self):
515544
display_code_compare(self)
516545

517546

518-
def _validate_color(clr, default_rgb_tuple: tuple) -> str:
547+
def _validate_color(clr, default_rgb_tuple: tuple[float, float, float]) -> str:
519548
if clr is None:
520-
return rgb2hex(default_rgb_tuple)
549+
return str(rgb2hex(default_rgb_tuple))
521550
elif isinstance(clr, tuple):
522-
return rgb2hex(clr)
551+
return str(rgb2hex(clr))
523552
elif isinstance(clr, str):
524553
return clr
525554
msg = f"clr has unexpected type: {type(clr)}"

0 commit comments

Comments
 (0)