|
2 | 2 | import os |
3 | 3 | import re |
4 | 4 | from pathlib import Path |
| 5 | +from typing import Optional, Tuple |
5 | 6 | from urllib.parse import urlparse, unquote |
6 | 7 | import pandas |
7 | 8 |
|
@@ -46,3 +47,83 @@ def read_csv_cached( |
46 | 47 | df = pandas.read_csv(filepath_or_buffer, **kwargs) |
47 | 48 | df.to_csv(cache_name, index=False) |
48 | 49 | return df |
| 50 | + |
| 51 | + |
def plot_waterfall(
    data: pandas.DataFrame,
    value_column: str,
    label_column: Optional[str] = None,
    total_label: str = "total",
    ax=None,
    colors: Tuple[str, str, str] = ("#2ca02c", "#d62728", "#1f77b4"),
):
    """
    Draws a waterfall chart from a dataframe.

    Every row of *data* is drawn as a bar starting at the running sum of
    the previous increments. A final bar, starting at zero, shows the total.
    Bars are placed at consecutive integer positions and the labels are
    only used as tick labels, so repeated labels (or a label equal to
    *total_label*) do not overlap. Missing increments (NaN) are not drawn
    and do not shift the running level of the following bars.

    :param data: dataframe containing increments
    :param value_column: column with increments, must be convertible
        to numeric values
    :param label_column: column with labels, index is used if None
    :param total_label: label used for the final total
    :param ax: existing axis or None to create one
    :param colors: positive, negative, total colors
    :return: axis, computed dataframe used to draw the chart, with
        columns ``label``, ``value``, ``start``, ``end``, ``kind``
    :raises ValueError: if a column is missing or if *colors* does not
        contain exactly three values

    .. versionadded:: 0.6.1
    """
    if value_column not in data.columns:
        raise ValueError(f"Unable to find column {value_column!r} in dataframe.")
    if label_column is not None and label_column not in data.columns:
        raise ValueError(f"Unable to find column {label_column!r} in dataframe.")
    if len(colors) != 3:
        raise ValueError(f"colors must contain 3 values, not {len(colors)}.")
    positive_color, negative_color, total_color = colors

    values = pandas.to_numeric(data[value_column], errors="raise").astype(float)
    labels = data[label_column] if label_column is not None else data.index
    labels = labels.astype(str)

    # The running level ignores missing increments (as the total does):
    # a plain cumsum would leave a NaN which, once shifted, would erase
    # the bar following the missing value.
    starts = values.fillna(0.0).cumsum().shift(1, fill_value=0.0)
    plot_df = pandas.DataFrame(
        {
            "label": labels,
            "value": values,
            "start": starts,
            "end": starts + values,
            "kind": "variation",
        }
    )

    # Series.sum skips NaN and returns 0.0 for an empty series.
    total = float(values.sum())
    total_row = pandas.DataFrame(
        {
            "label": [total_label],
            "value": [total],
            "start": [0.0],
            "end": [total],
            "kind": ["total"],
        }
    )
    plot_df = pandas.concat([plot_df, total_row], axis=0, ignore_index=True)

    if ax is None:
        # Imported lazily so that matplotlib is only needed when no axis is given.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots(1, 1)

    bar_colors = [
        total_color
        if kind == "total"
        else (positive_color if value >= 0 else negative_color)
        for value, kind in zip(plot_df["value"], plot_df["kind"])
    ]
    # Integer positions instead of string labels: matplotlib maps identical
    # strings to the same categorical position, which would stack bars
    # sharing a label on top of each other.
    positions = list(range(len(plot_df)))
    ax.bar(
        positions,
        plot_df["value"].to_numpy(),
        bottom=plot_df["start"].to_numpy(),
        color=bar_colors,
        tick_label=plot_df["label"].tolist(),
    )

    ax.axhline(0, color="black", linewidth=0.8)
    ax.set_ylabel(value_column)
    ax.set_xlabel(label_column or "index")

    return ax, plot_df
0 commit comments