diff --git a/HISTORY.rst b/HISTORY.rst index 558bf9a..7b4f07f 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History v0.2.0dev --------- +* New keyword argument for ``display_code_comparison``, ``include_overrides_only`` which when True (the default), only includes the classes that override the function of interest. * Improved typing (`PR 42 `_) v0.2.0 diff --git a/docs/examples/ex_004_code_comparison.rst b/docs/examples/ex_004_code_comparison.rst index d55b35d..8f4beb5 100644 --- a/docs/examples/ex_004_code_comparison.rst +++ b/docs/examples/ex_004_code_comparison.rst @@ -23,3 +23,7 @@ The following screenshot shows the code comparison widget in a Jupyter notebook: .. image:: /resources/inherit_code_widget.gif :width: 800 +By default, the code comparison widget will only display child classes that override +the function being compared. To include all classes, set ``include_overrides_only==False`` +when calling ``display_code_comparison``. + diff --git a/inheritance_explorer/_widget_support.py b/inheritance_explorer/_widget_support.py index 43bd0d4..b6a9bf2 100644 --- a/inheritance_explorer/_widget_support.py +++ b/inheritance_explorer/_widget_support.py @@ -24,17 +24,32 @@ def find_closest_source(cgt: ClassGraphTree, node_id: int): # it does not, get base class source -def display_code_compare(cgt: ClassGraphTree): - _display_code_compare(cgt) +def display_code_compare(cgt: ClassGraphTree, include_overrides_only: bool = True): + _display_code_compare(cgt, include_overrides_only=include_overrides_only) + + +def _get_class_names( + cgt: ClassGraphTree, include_overrides_only: bool = True +) -> list[str]: + if include_overrides_only is False: + names_classes = [nm for nm in cgt._node_map_r.keys()] + else: + names_classes = [] + for node_id in cgt._override_src.keys(): + names_classes.append(cgt._node_map[node_id]) + + names_classes.sort() + return names_classes def _display_code_compare( cgt: ClassGraphTree, class_1_name: Optional[str] = None, class_2_name: Optional[str] = None, + include_overrides_only: bool = True, ): - names_classes = [i.child_name for i in cgt._node_list] - names_classes.sort() + + names_classes = _get_class_names(cgt, include_overrides_only=include_overrides_only) class_dropdown_1 = ipywidgets.Dropdown(options=names_classes.copy()) class_dropdown_2 = ipywidgets.Dropdown(options=names_classes.copy()) diff --git a/inheritance_explorer/inheritance_explorer.py b/inheritance_explorer/inheritance_explorer.py index 35bdc60..125bd7a 100644 --- a/inheritance_explorer/inheritance_explorer.py +++ b/inheritance_explorer/inheritance_explorer.py @@ -532,16 +532,22 @@ def get_multiple_source_code( src_dict[src_key] = self.get_source_code(src_key) return src_dict - def display_code_comparison(self): + def display_code_comparison(self, include_overrides_only: bool = True): """ show the code comparison widget + + Parameters + ---------- + include_overrides_only: bool + if True (default), only displays the classes that override the function + being compared. """ # add a check that we are running from a notebook? if self.funcname is not None: from inheritance_explorer._widget_support import display_code_compare - display_code_compare(self) + display_code_compare(self, include_overrides_only=include_overrides_only) def _validate_color(clr, default_rgb_tuple: tuple[float, float, float]) -> str: diff --git a/inheritance_explorer/tests/test_widgets.py b/inheritance_explorer/tests/test_widgets.py index 209b535..5399bfe 100644 --- a/inheritance_explorer/tests/test_widgets.py +++ b/inheritance_explorer/tests/test_widgets.py @@ -1,7 +1,7 @@ import pytest from inheritance_explorer._testing import ClassForTesting -from inheritance_explorer._widget_support import _display_code_compare +from inheritance_explorer._widget_support import _display_code_compare, _get_class_names from inheritance_explorer.inheritance_explorer import ClassGraphTree @@ -16,4 +16,16 @@ def test_code_comparison_widget_from_cgt(cgt): def test_secret_code_comparison_widget(cgt): _display_code_compare(cgt, class_1_name="ClassForTesting4") - _display_code_compare(cgt, class_2_name="ClassForTesting3") + _display_code_compare( + cgt, class_2_name="ClassForTesting3", include_overrides_only=False + ) + + +def test_get_class_names(cgt): + + cnames_all = _get_class_names(cgt, include_overrides_only=False) + assert len(cnames_all) == len(cgt._node_list) + + cnames_override = _get_class_names(cgt, include_overrides_only=True) + assert len(cnames_override) < len(cgt._node_list) + assert len(cnames_override) == len(cgt._override_src)