Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

v0.2.0dev
---------
* New keyword argument for ``display_code_comparison``, ``include_overrides_only`` which when True (the default), only includes the classes that override the function of interest.
* Improved typing (`PR 42 <https://github.com/data-exp-lab/inheritance_explorer/pull/42>`_)

v0.2.0
Expand Down
4 changes: 4 additions & 0 deletions docs/examples/ex_004_code_comparison.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ The following screenshot shows the code comparison widget in a Jupyter notebook:
.. image:: /resources/inherit_code_widget.gif
:width: 800

By default, the code comparison widget will only display child classes that override
the function being compared. To include all classes, set ``include_overrides_only==False``
when calling ``display_code_comparison``.

23 changes: 19 additions & 4 deletions inheritance_explorer/_widget_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,32 @@ def find_closest_source(cgt: ClassGraphTree, node_id: int):
# it does not, get base class source


def display_code_compare(cgt: ClassGraphTree):
_display_code_compare(cgt)
def display_code_compare(cgt: ClassGraphTree, include_overrides_only: bool = True):
_display_code_compare(cgt, include_overrides_only=include_overrides_only)


def _get_class_names(
cgt: ClassGraphTree, include_overrides_only: bool = True
) -> list[str]:
if include_overrides_only is False:
names_classes = [nm for nm in cgt._node_map_r.keys()]
else:
names_classes = []
for node_id in cgt._override_src.keys():
names_classes.append(cgt._node_map[node_id])

names_classes.sort()
return names_classes


def _display_code_compare(
cgt: ClassGraphTree,
class_1_name: Optional[str] = None,
class_2_name: Optional[str] = None,
include_overrides_only: bool = True,
):
names_classes = [i.child_name for i in cgt._node_list]
names_classes.sort()

names_classes = _get_class_names(cgt, include_overrides_only=include_overrides_only)

class_dropdown_1 = ipywidgets.Dropdown(options=names_classes.copy())
class_dropdown_2 = ipywidgets.Dropdown(options=names_classes.copy())
Expand Down
10 changes: 8 additions & 2 deletions inheritance_explorer/inheritance_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,16 +532,22 @@ def get_multiple_source_code(
src_dict[src_key] = self.get_source_code(src_key)
return src_dict

def display_code_comparison(self):
def display_code_comparison(self, include_overrides_only: bool = True):
"""
show the code comparison widget

Parameters
----------
include_overrides_only: bool
if True (default), only displays the classes that override the function
being compared.
"""

# add a check that we are running from a notebook?
if self.funcname is not None:
from inheritance_explorer._widget_support import display_code_compare

display_code_compare(self)
display_code_compare(self, include_overrides_only=include_overrides_only)


def _validate_color(clr, default_rgb_tuple: tuple[float, float, float]) -> str:
Expand Down
16 changes: 14 additions & 2 deletions inheritance_explorer/tests/test_widgets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from inheritance_explorer._testing import ClassForTesting
from inheritance_explorer._widget_support import _display_code_compare
from inheritance_explorer._widget_support import _display_code_compare, _get_class_names
from inheritance_explorer.inheritance_explorer import ClassGraphTree


Expand All @@ -16,4 +16,16 @@ def test_code_comparison_widget_from_cgt(cgt):

def test_secret_code_comparison_widget(cgt):
_display_code_compare(cgt, class_1_name="ClassForTesting4")
_display_code_compare(cgt, class_2_name="ClassForTesting3")
_display_code_compare(
cgt, class_2_name="ClassForTesting3", include_overrides_only=False
)


def test_get_class_names(cgt):

cnames_all = _get_class_names(cgt, include_overrides_only=False)
assert len(cnames_all) == len(cgt._node_list)

cnames_override = _get_class_names(cgt, include_overrides_only=True)
assert len(cnames_override) < len(cgt._node_list)
assert len(cnames_override) == len(cgt._override_src)