Skip to content

Commit 0d5d398

Browse files
Copilotxadupre
andauthored
Add waterfall chart helper and focused tests
Agent-Logs-Url: https://github.com/sdpython/teachpyx/sessions/e6e8e7c7-c0a1-47b3-9763-5ebb04affcb9 Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent 047dbfd commit 0d5d398

File tree

2 files changed

+100
-1
lines changed

2 files changed

+100
-1
lines changed

_unittests/ut_tools/test_pandas.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
2+
import pandas
23
from teachpyx.ext_test_case import ExtTestCase
3-
from teachpyx.tools.pandas import read_csv_cached
4+
from teachpyx.tools.pandas import plot_waterfall, read_csv_cached
45

56

67
class TestPandas(ExtTestCase):
@@ -14,6 +15,23 @@ def test_read_csv_cached(self):
1415
self.assertEqual(df.shape, df2.shape)
1516
self.assertEqual(list(df.columns), list(df2.columns))
1617

18+
def test_plot_waterfall(self):
19+
df = pandas.DataFrame(
20+
{
21+
"name": ["A", "B", "C"],
22+
"delta": [10, -3, 5],
23+
}
24+
)
25+
ax, plot_df = plot_waterfall(df, "delta", "name", total_label="TOTAL")
26+
self.assertEqual(ax.__class__.__name__, "Axes")
27+
self.assertEqual(list(plot_df["label"]), ["A", "B", "C", "TOTAL"])
28+
self.assertEqual(list(plot_df["start"]), [0.0, 10.0, 7.0, 0.0])
29+
self.assertEqual(list(plot_df["end"]), [10.0, 7.0, 12.0, 12.0])
30+
31+
def test_plot_waterfall_missing_column(self):
32+
df = pandas.DataFrame({"name": ["A"], "delta": [1]})
33+
self.assertRaise(lambda: plot_waterfall(df, "missing", "name"), ValueError)
34+
1735

1836
if __name__ == "__main__":
1937
unittest.main()

teachpyx/tools/pandas.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import re
44
from pathlib import Path
5+
from typing import Optional, Tuple
56
from urllib.parse import urlparse, unquote
67
import pandas
78

@@ -46,3 +47,83 @@ def read_csv_cached(
4647
df = pandas.read_csv(filepath_or_buffer, **kwargs)
4748
df.to_csv(cache_name, index=False)
4849
return df
50+
51+
52+
def plot_waterfall(
53+
data: pandas.DataFrame,
54+
value_column: str,
55+
label_column: Optional[str] = None,
56+
total_label: str = "total",
57+
ax=None,
58+
colors: Tuple[str, str, str] = ("#2ca02c", "#d62728", "#1f77b4"),
59+
):
60+
"""
61+
Draws a waterfall chart from a dataframe.
62+
63+
:param data: dataframe containing increments
64+
:param value_column: column with increments
65+
:param label_column: column with labels, index is used if None
66+
:param total_label: label used for the final total
67+
:param ax: existing axis or None to create one
68+
:param colors: positive, negative, total colors
69+
:return: axis, computed dataframe used to draw the chart
70+
71+
.. versionadded:: 0.6.1
72+
"""
73+
if value_column not in data.columns:
74+
raise ValueError(f"Unable to find column {value_column!r} in dataframe.")
75+
if label_column is not None and label_column not in data.columns:
76+
raise ValueError(f"Unable to find column {label_column!r} in dataframe.")
77+
if len(colors) != 3:
78+
raise ValueError(f"colors must contain 3 values, not {len(colors)}.")
79+
80+
values = pandas.to_numeric(data[value_column], errors="raise").astype(float)
81+
labels = data[label_column] if label_column is not None else data.index
82+
labels = labels.astype(str)
83+
84+
starts = values.cumsum().shift(1, fill_value=0.0)
85+
plot_df = pandas.DataFrame(
86+
{
87+
"label": labels,
88+
"value": values,
89+
"start": starts,
90+
"end": starts + values,
91+
"kind": "variation",
92+
}
93+
)
94+
95+
total = float(values.sum()) if len(values) > 0 else 0.0
96+
total_row = pandas.DataFrame(
97+
{
98+
"label": [total_label],
99+
"value": [total],
100+
"start": [0.0],
101+
"end": [total],
102+
"kind": ["total"],
103+
}
104+
)
105+
plot_df = pandas.concat([plot_df, total_row], axis=0, ignore_index=True)
106+
107+
if ax is None:
108+
import matplotlib.pyplot as plt
109+
110+
_, ax = plt.subplots(1, 1)
111+
112+
bar_colors = [
113+
colors[2]
114+
if kind == "total"
115+
else (colors[0] if value >= 0 else colors[1])
116+
for value, kind in zip(plot_df["value"], plot_df["kind"])
117+
]
118+
ax.bar(
119+
plot_df["label"],
120+
plot_df["value"],
121+
bottom=plot_df["start"],
122+
color=bar_colors,
123+
)
124+
125+
ax.axhline(0, color="black", linewidth=0.8)
126+
ax.set_ylabel(value_column)
127+
ax.set_xlabel(label_column or "index")
128+
129+
return ax, plot_df

0 commit comments

Comments
 (0)