diff --git a/_unittests/ut_faq/test_faq_python.py b/_unittests/ut_faq/test_faq_python.py index e2e0f364..5108e8e8 100644 --- a/_unittests/ut_faq/test_faq_python.py +++ b/_unittests/ut_faq/test_faq_python.py @@ -1,6 +1,14 @@ import unittest import datetime -from teachpyx.faq.faq_python import get_month_name, get_day_name, class_getitem +import matplotlib +from teachpyx.faq.faq_python import ( + class_getitem, + get_day_name, + get_month_name, + graph_sankey, +) + +matplotlib.use("Agg") class TestFaqPython(unittest.TestCase): @@ -17,6 +25,27 @@ def test_day_name(self): def test_class_getitem(self): class_getitem() + def test_graph_sankey(self): + ax, diagrams = graph_sankey( + [1, -0.25, -0.75], + labels=["input", "loss", "output"], + orientations=[0, 1, -1], + trunklength=2.0, + title="flux", + ) + self.assertEqual(ax.get_title(), "flux") + self.assertEqual(len(diagrams), 1) + + def test_graph_sankey_errors(self): + with self.assertRaisesRegex(ValueError, "at least two values"): + graph_sankey([1]) + with self.assertRaisesRegex(ValueError, "must be 0"): + graph_sankey([1, -0.5]) + with self.assertRaisesRegex(ValueError, "same length"): + graph_sankey([1, -1], labels=["only_one"]) + with self.assertRaisesRegex(ValueError, "same length"): + graph_sankey([1, -1], orientations=[0, 1, -1]) + if __name__ == "__main__": unittest.main() diff --git a/teachpyx/faq/faq_python.py b/teachpyx/faq/faq_python.py index 9e4fa380..bd965bda 100644 --- a/teachpyx/faq/faq_python.py +++ b/teachpyx/faq/faq_python.py @@ -4,6 +4,8 @@ import os import re +_FLOW_BALANCE_TOLERANCE = 1e-10 + def entier_grande_taille(): """ @@ -850,3 +852,68 @@ def __init__(self): a = A[2]() assert a.__class__.__name__ == "A2" + + +def graph_sankey( + flows, labels=None, orientations=None, ax=None, title=None, **kwargs +): + """ + Draws a :epkg:`Sankey` graph to represent flows. + + :param flows: list of positive/negative flows + :param labels: labels associated to each flow + :param orientations: list of orientations (-1, 0, 1) + :param ax: axis to use, if None, creates one + :param title: graph title + :param kwargs: additional parameters forwarded to + ``matplotlib.sankey.Sankey.add`` + :return: axis, sankey diagrams + + Example:: + + import matplotlib.pyplot as plt + from teachpyx.faq.faq_python import graph_sankey + + ax, _ = graph_sankey( + [1, -0.25, -0.75], + labels=["input", "loss", "output"], + orientations=[0, 1, -1], + title="flux", + ) + plt.show() + """ + if len(flows) < 2: + raise ValueError(f"flows must contain at least two values, got {len(flows)}.") + total_flow = sum(flows) + if abs(total_flow) > _FLOW_BALANCE_TOLERANCE: + raise ValueError( + "The sum of all flows must be 0 " + f"(within tolerance {_FLOW_BALANCE_TOLERANCE}), got {total_flow}." + ) + if labels is None: + labels = [None] * len(flows) + if orientations is None: + orientations = [0] * len(flows) + if len(labels) != len(flows): + raise ValueError( + "labels and flows must have the same length, " + f"got {len(labels)} and {len(flows)}." + ) + if len(orientations) != len(flows): + raise ValueError( + "orientations and flows must have the same length, " + f"got {len(orientations)} and {len(flows)}." + ) + + import matplotlib.pyplot as plt + from matplotlib.sankey import Sankey + + if ax is None: + _, ax = plt.subplots(1, 1) + + sankey = Sankey(ax=ax) + sankey.add(flows=flows, labels=labels, orientations=orientations, **kwargs) + diagrams = sankey.finish() + if title is not None: + ax.set_title(title) + return ax, diagrams