Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[run]
source = inheritance_explorer

[report]
fail_under = 97
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
name: Run tests

on:
pull_request:
paths:
- '**.py'
- '**run-tests.yml'
pull_request:
schedule:
- cron: "30 1 * * 1"

Expand Down
39 changes: 39 additions & 0 deletions .github/workflows/type-check.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: type checking

on:
pull_request:
paths:
- inheritance_explorer/**/*.py
- pyproject.toml
- requirements/typecheck.txt
- .github/workflows/type-checking.yaml
workflow_dispatch:

jobs:
build:
runs-on: ubuntu-latest
name: type check
timeout-minutes: 60

steps:
- name: Checkout repo
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
# run with oldest supported python version
# so that we always get compatible versions of
# core dependencies at type-check time
python-version: '3.10'

- name: Build
run: |
python3 -m pip install --upgrade pip
python3 -m pip install -r requirements/typecheck.txt

- name: list installed deps
run: python -m pip list

- name: Run mypy
run: mypy inheritance_explorer
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
History
=======

v0.2.0dev
---------
* Improved typing (`PR 42 <https://github.com/data-exp-lab/inheritance_explorer/pull/42>`_)

v0.2.0
------
* enable styling on interactive graphs (see `new example <https://inheritance-explorer.readthedocs.io/en/latest/examples/ex_006_interactive_graph_styles.html>`_)
Expand Down
20 changes: 6 additions & 14 deletions docs/examples/ex_001_basic_usage.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/examples/ex_002_inheritance_scope.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.10.11"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/ex_003_Qwidgets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.10.11"
}
},
"nbformat": 4,
Expand Down
8 changes: 5 additions & 3 deletions inheritance_explorer/_testing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
class ClassForTesting:
def use_this_func(self, a):
misc_attr: int = 1

def use_this_func(self, a: int) -> int:
return a


class ClassForTesting2(ClassForTesting):
def use_this_func(self, a):
def use_this_func(self, a: int) -> int:
b = a * 10
return b

Expand All @@ -14,7 +16,7 @@ class ClassForTesting3(ClassForTesting):


class ClassForTesting4(ClassForTesting2):
def use_this_func(self, a):
def use_this_func(self, a: int) -> int:
b = a * 10
c = b + 10
return c
4 changes: 2 additions & 2 deletions inheritance_explorer/_widget_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def update_source_2(event):

class_dropdown_1.observe(update_source_1, ["value"])
class_dropdown_2.observe(update_source_2, ["value"])
update_source_1(None)
update_source_2(None)
update_source_1(None) # type: ignore[no-untyped-call]
update_source_2(None) # type: ignore[no-untyped-call]

if class_1_name is not None:
class_dropdown_1.value = class_1_name
Expand Down
135 changes: 82 additions & 53 deletions inheritance_explorer/inheritance_explorer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import collections
import inspect
import textwrap
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, OrderedDict

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import numpy.typing as npt
import pydot
from matplotlib.axes import Axes
from matplotlib.colors import rgb2hex
Expand All @@ -24,7 +25,7 @@
color: Optional[str] = "#000000",
):
self.child = child
self.child_name = child.__name__
self.child_name: str = str(child.__name__)
self._child_id = child_id
self.parent = parent

Expand All @@ -41,10 +42,13 @@
return str(self._child_id)

@property
def parent_id(self) -> str:
def parent_id(self) -> str | None:
if self._parent_id:
return str(self._parent_id)
return
return None


_similarity_container_types = PycodeSimilarity


class ClassGraphTree:
Expand Down Expand Up @@ -76,38 +80,44 @@
self,
baseclass: Any,
funcname: Optional[str] = None,
default_color: Optional[str] = "#000000",
func_override_color: Optional[str] = "#ff0000",
similarity_cutoff: Optional[float] = 0.75,
max_recursion_level: Optional[int] = 500,
classes_to_exclude: Optional[List[str]] = None,
default_color: str = "#000000",
func_override_color: str = "#ff0000",
similarity_cutoff: float = 0.75,
max_recursion_level: int = 500,
classes_to_exclude: Optional[list[str]] = None,
):

self.baseclass = baseclass
self.basename: str = baseclass.__name__
self.funcname = funcname
self._tracking_function = self.funcname is not None
self.max_recursion_level = max_recursion_level
self._nodenum: int = 0
self._node_list = [] # a list of unique ChildNodes
self._node_map = {} # map of global node index to node name
self._override_src = collections.OrderedDict()
self._override_src_files = {}
self._node_list: list[_ChildNode] = [] # a list of unique ChildNodes
self._node_map: dict[int, str] = {} # map of global node index to node name
self._override_src: OrderedDict[int, str] = collections.OrderedDict()
self._override_src_files: dict[int, str] = {}
self._current_node = 1 # the current global node, must start at 1
self._default_color = default_color
self._override_color = func_override_color
self._graphviz_args_kwargs = {}
self.similarity_container = None
self.similarity_results = None
self._graphviz_args_kwargs: dict[str, Any] = {}
self.similarity_container: _similarity_container_types | None = None
self.similarity_results: dict[str, npt.NDArray[Any]]
self.similarity_cutoff = similarity_cutoff
if classes_to_exclude is None:
classes_to_exclude = []
self.classes_to_exclude = classes_to_exclude
self._build()
self._node_map_r = {v: k for k, v in self._node_map.items()} # name to index
self._node_map_r: dict[str, int] = {
v: k for k, v in self._node_map.items()
} # name to index

def _get_source_info(self, obj) -> Optional[str]:
f = getattr(obj, self.funcname)
if isinstance(f, collections.abc.Callable):
if self.funcname is None:
raise RuntimeError("this functionality requires function tracking.")
fname: str = self.funcname
f = getattr(obj, fname)
if isinstance(f, collections.abc.Callable): # type: ignore[arg-type]
return f"{inspect.getsourcefile(f)}:{inspect.getsourcelines(f)[1]}"
return None

Expand Down Expand Up @@ -157,8 +167,12 @@
# store the source code of funcname for the current class and node
# clss: a class
# current_node: the
f = getattr(clss, self.funcname)
if isinstance(f, collections.abc.Callable):
if self.funcname is None:
raise RuntimeError("this functionality requires function tracking.")
fname: str = self.funcname

f = getattr(clss, fname)
if isinstance(f, collections.abc.Callable): # type: ignore[arg-type]
src = textwrap.dedent(inspect.getsource(f))
self._override_src_files[current_node] = (
f"{inspect.getsourcefile(f)}:{inspect.getsourcelines(f)[1]}"
Expand All @@ -167,21 +181,28 @@

def check_source_similarity(
self,
SimilarityContainer=PycodeSimilarity,
method="reference",
similarity_container_class: str = "PycodeSimilarity",
method: str = "reference",
reference: Optional[int] = None,
):
# compares all the source code of the child methods that have
# over-ridden funcname

if reference is None:
reference = 1 # use whatever the basenode is
ref = 1 # use whatever the basenode is
else:
ref = reference

if similarity_container_class == "PycodeSimilarity":
SimClass = PycodeSimilarity
else:
raise ValueError(f"unexpected value, {similarity_container_class=}")

self.similarity_container = SimilarityContainer(method=method)
sim = self.similarity_container.run(self._override_src, reference=reference)
self.similarity_container = SimClass(method=method)
sim = self.similarity_container.run(self._override_src, reference=ref)
return sim

def _build(self):
def _build(self) -> None:

# construct the first node
color = self._get_baseclass_color()
Expand All @@ -201,11 +222,12 @@
# construct the full similarity matrix
s_c = PycodeSimilarity(method="permute")
_, sim_matrix, sim_axis = s_c.run(self._override_src)
sim_axis = np.array(sim_axis)
assert isinstance(sim_matrix, np.ndarray)
sim_axis_array = np.array(sim_axis)
sim_axis_names = np.array([c.child_name for c in self._node_list])
self.similarity_results = {
"matrix": sim_matrix,
"axis": sim_axis,
"axis": sim_axis_array,
"axis_names": sim_axis_names,
}

Expand All @@ -216,9 +238,9 @@
rowvals = M[irow, :]
indxs = np.where(rowvals >= cutoff_sim)[0]
indxs = indxs[indxs != irow] # these are matrix indeces
node_ids = sim_axis[indxs]
node_ids = sim_axis_array[indxs]
if len(node_ids) > 0:
this_child = sim_axis[irow]
this_child = sim_axis_array[irow]
similarity_sets[this_child] = set(node_ids.tolist())
self.similarity_sets = similarity_sets

Expand Down Expand Up @@ -292,7 +314,7 @@
ax: Optional[Axes] = None,
colorbar: Optional[bool] = True,
**kwargs,
) -> Tuple[dict, Axes]:
) -> tuple[dict[int, str], Axes]:
"""
add the similarity plot to a matplotlib axis (or create a new one)

Expand Down Expand Up @@ -340,16 +362,16 @@
sim_labels = [
self._node_list[cid - 1].child_name for cid in self._override_src.keys()
]
sim_labels = {lid: label for lid, label in enumerate(sim_labels)}
return sim_labels, ax
sim_labels_dict = {lid: label for lid, label in enumerate(sim_labels)}
return sim_labels_dict, ax

def build_interactive_graph(
self,
include_similarity: bool = True,
node_style: dict = None,
edge_style: dict = None,
similarity_edge_style: dict = None,
override_node_color: Union[str, tuple] = None,
node_style: dict[str, Any] | None = None,
edge_style: dict[str, Any] | None = None,
similarity_edge_style: dict[str, Any] | None = None,
override_node_color: str | tuple[float, ...] | None = None,
**kwargs,
) -> Network:
"""
Expand Down Expand Up @@ -454,34 +476,41 @@
network_wrapper.from_nx(grph)
return network_wrapper

def get_source_code(self, node: Union[str, int]) -> str:
def get_source_code(self, node: int | str) -> str:
"""
retrieve the source code of the comparison function for a
specified node

Parameters
----------
node: Union[str, int]
node: int
the node to fetch the source code for

Returns
-------
str
a string containing the source code for the node.
"""
if node in self._override_src:
return self._override_src[node]
node_id: int

if not isinstance(node, int) and not isinstance(node, str):
raise TypeError("Unexpected type for node")

Check warning on line 497 in inheritance_explorer/inheritance_explorer.py

View check run for this annotation

Codecov / codecov/patch

inheritance_explorer/inheritance_explorer.py#L497

Added line #L497 was not covered by tests

if isinstance(node, int) and node in self._node_map:
node_id = node
elif isinstance(node, str) and node in self._node_map_r:
node_id = self._node_map_r[node]
if node_id in self._override_src:
return self._override_src[node_id]
else:
raise ValueError(
f"node {node} does not override the " f"chosen function."
)
raise KeyError(f"Could not find node for {node}")
else:
raise ValueError(f"Could not find node for {node}")

if node_id in self._override_src:
return self._override_src[node_id]
else:
raise ValueError(f"node {node} does not override the chosen function.")

def get_multiple_source_code(self, node_1: Union[str, int], *args) -> dict:
def get_multiple_source_code(
self, node_1: int | str, *args
) -> dict[int | str, str]:
"""
Retrieve the source code for multiple nodes

Expand Down Expand Up @@ -515,11 +544,11 @@
display_code_compare(self)


def _validate_color(clr, default_rgb_tuple: tuple) -> str:
def _validate_color(clr, default_rgb_tuple: tuple[float, float, float]) -> str:
if clr is None:
return rgb2hex(default_rgb_tuple)
return str(rgb2hex(default_rgb_tuple))
elif isinstance(clr, tuple):
return rgb2hex(clr)
return str(rgb2hex(clr))
elif isinstance(clr, str):
return clr
msg = f"clr has unexpected type: {type(clr)}"
Expand Down
Loading