"""pyplots.ai
heatmap-cohort-retention: Cohort Retention Heatmap
Library: pygal | Python 3.13
Quality: pending | Created: 2026-03-16
"""

import sys

import numpy as np


# Import pygal while avoiding a name collision with this file's own name.
# The first ``sys.path`` entry is taken off for the duration of the import and
# put back afterwards.  The entry itself is removed (even when it is the empty
# string), and ``finally`` guarantees it is restored if the import fails.
_cwd = sys.path[0] if sys.path else None
if _cwd is not None:
    del sys.path[0]

try:
    from pygal.graph.graph import Graph  # noqa: E402
    from pygal.style import Style  # noqa: E402
finally:
    if _cwd is not None:
        sys.path.insert(0, _cwd)


class CohortRetentionHeatmap(Graph):
    """Cohort-retention heatmap drawn by hand into a pygal ``Graph``.

    The class overrides ``_plot`` and ``_compute`` and writes its own SVG
    nodes (text, rects, a gradient colorbar) into ``self.nodes["plot"]``.

    Keyword arguments consumed here (everything else is forwarded to
    ``Graph.__init__``):

    matrix_data
        List of rows; each row is a list of percentages.  Rows may have
        different lengths and ``None`` entries are skipped.
    row_labels, col_labels
        Labels drawn left of each row / above each column.
    cohort_sizes
        Per-row integers rendered as ``n=1,234`` under the row label.
    colormap
        List of ``"#rrggbb"`` strings, ordered from the lowest to the
        highest value.  An empty colormap falls back to a built-in
        sequential green palette.
    col_title, row_title, colorbar_title
        Titles for the column axis, the row axis and the colorbar.
    """

    _series_margin = 0

    # Used when the caller supplies no colormap, so colour lookups never
    # index into an empty sequence.
    _DEFAULT_COLORMAP = (
        "#f7fcf5",
        "#e5f5e0",
        "#c7e9c0",
        "#a1d99b",
        "#74c476",
        "#41ab5d",
        "#238b45",
        "#005a32",
    )

    def __init__(self, *args, **kwargs):
        # Custom options are popped first so only the remaining keyword
        # arguments reach Graph.__init__.
        self.matrix_data = kwargs.pop("matrix_data", [])
        self.row_labels = kwargs.pop("row_labels", [])
        self.col_labels = kwargs.pop("col_labels", [])
        self.cohort_sizes = kwargs.pop("cohort_sizes", [])
        self.colormap = kwargs.pop("colormap", [])
        self.col_title = kwargs.pop("col_title", "Months Since Signup")
        self.row_title = kwargs.pop("row_title", "Signup Cohort")
        self.colorbar_title = kwargs.pop("colorbar_title", "Retention %")
        super().__init__(*args, **kwargs)

    def _interpolate_color(self, value, min_val, max_val):
        """Return the ``#rrggbb`` colour for ``value`` on the colormap.

        ``value`` is normalised to ``[0, 1]`` between ``min_val`` and
        ``max_val`` (clamped), then linearly interpolated per RGB channel
        between the two neighbouring colormap entries.  When the range is
        degenerate (``min_val == max_val``) the middle entry is returned.
        """
        palette = self.colormap if len(self.colormap) else self._DEFAULT_COLORMAP
        last = len(palette) - 1
        if max_val == min_val:
            return palette[len(palette) // 2]

        normalized = max(0, min(1, (value - min_val) / (max_val - min_val)))
        pos = normalized * last
        lower = int(pos)
        upper = min(lower + 1, last)
        frac = pos - lower
        start, end = palette[lower], palette[upper]

        channels = []
        for offset in (1, 3, 5):
            a = int(start[offset:offset + 2], 16)
            b = int(end[offset:offset + 2], 16)
            # int() truncates toward zero.
            channels.append(int(a + (b - a) * frac))
        return "#{:02x}{:02x}{:02x}".format(*channels)

    def _get_text_color(self, bg_color):
        """Pick white or near-black text for a ``#rrggbb`` background.

        Uses the weighted brightness ``(299*R + 587*G + 114*B) / 1000`` with
        a threshold of 140: darker backgrounds get white text.
        """
        r = int(bg_color[1:3], 16)
        g = int(bg_color[3:5], 16)
        b = int(bg_color[5:7], 16)
        brightness = (r * 299 + g * 587 + b * 114) / 1000
        return "#ffffff" if brightness < 140 else "#222222"

    @staticmethod
    def _heatmap_font(size, weight=None):
        """Build the inline CSS ``style`` string used by all text nodes."""
        if weight is None:
            return f"font-size:{size}px;font-family:sans-serif"
        return f"font-size:{size}px;font-weight:{weight};font-family:sans-serif"

    def _heatmap_text(self, parent, x, y, content, fill, style, anchor=None, transform=None):
        """Append a ``<text>`` node to ``parent`` and return it.

        Attributes are set in a fixed order (anchor, fill, style, transform)
        so the emitted SVG is stable.
        """
        node = self.svg.node(parent, "text", x=x, y=y)
        if anchor is not None:
            node.set("text-anchor", anchor)
        node.set("fill", fill)
        node.set("style", style)
        if transform is not None:
            node.set("transform", transform)
        node.text = content
        return node

    def _plot(self):
        """Draw headers, row labels, cells and the colorbar.

        Returns without drawing when there is nothing to plot: no rows, only
        empty rows, or only ``None`` cells.
        """
        if not self.matrix_data:
            return

        n_rows = len(self.matrix_data)
        n_cols = max(len(row) for row in self.matrix_data)

        # Colour scale spans every non-None cell; period 0 is included.
        values = [v for row in self.matrix_data for v in row if v is not None]
        if n_cols == 0 or not values:
            # Nothing drawable; also avoids min()/max() on an empty list and
            # a division by zero when computing the cell width.
            return
        min_val = min(values)
        max_val = max(values)

        plot_width = self.view.width
        plot_height = self.view.height

        # Pixel space reserved around the grid for labels and the colorbar.
        label_margin_left = 460
        label_margin_right = 280
        label_margin_top = 130
        label_margin_bottom = 10

        available_width = plot_width - label_margin_left - label_margin_right
        available_height = plot_height - label_margin_top - label_margin_bottom

        cell_width = available_width / n_cols
        cell_height = available_height / (n_rows + 0.5)
        gap = 3

        grid_width = n_cols * (cell_width + gap) - gap
        grid_height = n_rows * (cell_height + gap) - gap

        x_offset = self.view.x(0) + label_margin_left + (available_width - grid_width) / 2
        y_offset = (
            self.view.y(n_rows)
            + label_margin_top
            + (available_height - grid_height - cell_height * 0.5) / 2
        )

        geo = {
            "x0": x_offset,
            "y0": y_offset,
            "cell_w": cell_width,
            "cell_h": cell_height,
            "gap": gap,
            "grid_w": grid_width,
            "grid_h": grid_height,
        }
        plot_node = self.nodes["plot"]

        col_font_size = self._draw_heatmap_column_headers(plot_node, geo)
        self._draw_heatmap_row_labels(plot_node, geo, col_font_size + 4)
        self._draw_heatmap_cells(plot_node, geo, min_val, max_val)
        self._draw_heatmap_colorbar(plot_node, geo, min_val, max_val)

    def _draw_heatmap_column_headers(self, plot_node, geo):
        """Draw the column-axis title and one label per column; return the label font size."""
        col_font_size = min(38, int(geo["cell_w"] * 0.48))

        self._heatmap_text(
            plot_node,
            geo["x0"] + geo["grid_w"] / 2,
            geo["y0"] - 80,
            str(self.col_title),
            "#555555",
            self._heatmap_font(col_font_size + 4, 600),
            anchor="middle",
        )

        for j, label in enumerate(self.col_labels):
            cx = geo["x0"] + j * (geo["cell_w"] + geo["gap"]) + geo["cell_w"] / 2
            self._heatmap_text(
                plot_node,
                cx,
                geo["y0"] - 18,
                str(label),
                "#333333",
                self._heatmap_font(col_font_size, 700),
                anchor="middle",
            )
        return col_font_size

    def _draw_heatmap_row_labels(self, plot_node, geo, title_font_size):
        """Draw each row label, its cohort size (when given) and the rotated row-axis title."""
        row_font_size = min(36, int(geo["cell_h"] * 0.50))
        size_font_size = int(row_font_size * 0.78)
        rx = geo["x0"] - 20

        for i, label in enumerate(self.row_labels):
            ry = geo["y0"] + i * (geo["cell_h"] + geo["gap"]) + geo["cell_h"] / 2

            self._heatmap_text(
                plot_node,
                rx,
                ry + row_font_size * 0.15,
                str(label),
                "#333333",
                self._heatmap_font(row_font_size, 600),
                anchor="end",
            )

            # Cohort size goes on a second, smaller line under the label.
            if i < len(self.cohort_sizes):
                self._heatmap_text(
                    plot_node,
                    rx,
                    ry + row_font_size * 0.15 + size_font_size + 4,
                    f"n={self.cohort_sizes[i]:,}",
                    "#888888",
                    self._heatmap_font(size_font_size, 400),
                    anchor="end",
                )

        row_title_x = geo["x0"] - 340
        row_title_y = geo["y0"] + geo["grid_h"] / 2
        self._heatmap_text(
            plot_node,
            row_title_x,
            row_title_y,
            str(self.row_title),
            "#555555",
            self._heatmap_font(title_font_size, 600),
            anchor="middle",
            transform=f"rotate(-90, {row_title_x}, {row_title_y})",
        )

    def _draw_heatmap_cells(self, plot_node, geo, min_val, max_val):
        """Draw one coloured, labelled rect (with tooltip) per non-None cell."""
        cell_width = geo["cell_w"]
        cell_height = geo["cell_h"]
        value_font_size = min(38, int(min(cell_width, cell_height) * 0.46))

        for i, row in enumerate(self.matrix_data):
            for j, value in enumerate(row):
                if value is None:
                    continue

                color = self._interpolate_color(value, min_val, max_val)
                text_color = self._get_text_color(color)

                cx = geo["x0"] + j * (cell_width + geo["gap"])
                cy = geo["y0"] + i * (cell_height + geo["gap"])

                cell_group = self.svg.node(plot_node, "g", class_="cell")
                rect = self.svg.node(
                    cell_group, "rect", x=cx, y=cy, width=cell_width, height=cell_height, rx=3, ry=3
                )
                rect.set("fill", color)
                rect.set("stroke", "#ffffff")
                rect.set("stroke-width", "2")

                # Tooltip: labels may be shorter than the matrix, so fall
                # back to an empty string instead of indexing out of range.
                cohort_label = self.row_labels[i] if i < len(self.row_labels) else ""
                period_label = self.col_labels[j] if j < len(self.col_labels) else ""
                self._tooltip_data(
                    cell_group,
                    f"{value:.1f}%",
                    cx + cell_width / 2,
                    cy + cell_height / 2,
                    xlabel=f"{cohort_label} / {period_label}",
                )

                self._heatmap_text(
                    cell_group,
                    cx + cell_width / 2,
                    cy + cell_height / 2 + value_font_size * 0.35,
                    f"{value:.0f}%",
                    text_color,
                    self._heatmap_font(value_font_size, 500),
                    anchor="middle",
                )

    def _draw_heatmap_colorbar(self, plot_node, geo, min_val, max_val):
        """Draw the vertical gradient colorbar, its five ticks and its title."""
        cb_width = 46
        cb_height = geo["grid_h"] * 0.75
        cb_x = geo["x0"] + geo["grid_w"] + 50
        cb_y = geo["y0"] + (geo["grid_h"] - cb_height) / 2

        # Top of the bar is max_val, bottom is min_val; 21 stops sample the
        # colormap evenly in between.
        defs = self.svg.node(plot_node, "defs")
        gradient = self.svg.node(defs, "linearGradient", id="cb-gradient", x1="0", y1="0", x2="0", y2="1")
        for frac_i in range(21):
            frac = frac_i / 20.0
            val = max_val - (max_val - min_val) * frac
            stop = self.svg.node(gradient, "stop", offset=f"{frac * 100}%")
            stop.set("stop-color", self._interpolate_color(val, min_val, max_val))

        cb_rect = self.svg.node(plot_node, "rect", x=cb_x, y=cb_y, width=cb_width, height=cb_height, rx=4, ry=4)
        cb_rect.set("fill", "url(#cb-gradient)")
        cb_rect.set("stroke", "#999999")
        cb_rect.set("stroke-width", "1.5")

        cb_label_size = 28
        ticks = [
            (0.0, max_val),
            (0.25, max_val * 0.75 + min_val * 0.25),
            (0.5, (min_val + max_val) / 2),
            (0.75, max_val * 0.25 + min_val * 0.75),
            (1.0, min_val),
        ]
        for frac, val in ticks:
            ty = cb_y + cb_height * frac
            tick = self.svg.node(plot_node, "line", x1=cb_x + cb_width, y1=ty, x2=cb_x + cb_width + 8, y2=ty)
            tick.set("stroke", "#666666")
            tick.set("stroke-width", "1.5")
            self._heatmap_text(
                plot_node,
                cb_x + cb_width + 14,
                ty + cb_label_size * 0.35,
                f"{val:.0f}%",
                "#333333",
                self._heatmap_font(cb_label_size),
            )

        self._heatmap_text(
            plot_node,
            cb_x + cb_width / 2,
            cb_y - 20,
            str(self.colorbar_title),
            "#333333",
            self._heatmap_font(cb_label_size + 2, 600),
            anchor="middle",
        )

    def _compute(self):
        """Set the box extents to the matrix size (1 x 1 when there is no data)."""
        n_rows = len(self.matrix_data) if self.matrix_data else 1
        n_cols = max(len(row) for row in self.matrix_data) if self.matrix_data else 1
        self._box.xmin = 0
        self._box.xmax = n_cols
        self._box.ymin = 0
        self._box.ymax = n_rows


# Data: monthly signup cohorts and their retention percentages.
# Seeded so the generated matrix is reproducible.
np.random.seed(42)

cohort_labels = [
    f"{month} 2024"
    for month in ("Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct")
]
n_cohorts = len(cohort_labels)
n_max_periods = 10

cohort_sizes = [1200, 1350, 980, 1520, 1100, 1430, 1280, 1050, 1380, 1150]

# Baseline retention per month since signup, decaying over time
base_retention = np.array([100.0, 68.0, 52.0, 43.0, 37.0, 33.0, 30.0, 28.0, 26.5, 25.0])

# Triangular matrix: each later cohort has one fewer observed month
matrix = []
for cohort_idx in range(n_cohorts):
    # Later cohorts retain slightly better
    cohort_boost = cohort_idx * 0.8
    retention_row = []
    for period in range(n_max_periods - cohort_idx):
        if period == 0:
            # Month 0 is fixed at 100 and draws no random number
            retention_row.append(100.0)
            continue
        jitter = np.random.uniform(-2.5, 2.5)
        rate = base_retention[period] + cohort_boost + jitter
        rate = max(5.0, min(100.0, rate))
        retention_row.append(round(rate, 1))
    matrix.append(retention_row)

period_labels = ["Month {}".format(k) for k in range(n_max_periods)]

# Sequential green colormap, light (low) to dark (high)
green_colormap = [
    "#f7fcf5",
    "#e5f5e0",
    "#c7e9c0",
    "#a1d99b",
    "#74c476",
    "#41ab5d",
    "#238b45",
    "#005a32",
]

# Style
style_options = {
    "background": "white",
    "plot_background": "white",
    "foreground": "#333333",
    "foreground_strong": "#333333",
    "foreground_subtle": "#666666",
    "colors": ("#238b45",),
    "title_font_size": 54,
    "legend_font_size": 28,
    "label_font_size": 34,
    "value_font_size": 28,
    "tooltip_font_size": 26,
    "font_family": "sans-serif",
}
custom_style = Style(**style_options)

# Chart
heatmap_inputs = {
    "matrix_data": matrix,
    "row_labels": cohort_labels,
    "col_labels": period_labels,
    "cohort_sizes": cohort_sizes,
    "colormap": green_colormap,
}
chart_options = {
    "width": 4800,
    "height": 2700,
    "style": custom_style,
    "title": "heatmap-cohort-retention \u00b7 pygal \u00b7 pyplots.ai",
    "show_legend": False,
    "margin": 100,
    "margin_top": 200,
    "margin_bottom": 30,
    "margin_left": 120,
    "margin_right": 120,
    "show_x_labels": False,
    "show_y_labels": False,
}
chart = CohortRetentionHeatmap(**chart_options, **heatmap_inputs)

# A single placeholder series is registered; the heatmap itself is drawn
# from matrix_data, not from this series.
chart.add("", [0])

# Save
chart.render_to_file("plot.svg")
chart.render_to_png("plot.png")

inline_svg = chart.render(is_unicode=True)
html_content = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>heatmap-cohort-retention - pygal</title>
    <style>
        body {{ margin: 0; display: flex; justify-content: center; align-items: center; min-height: 100vh; background: #f5f5f5; }}
        .chart {{ max-width: 100%; height: auto; }}
    </style>
</head>
<body>
    <figure class="chart">
        {inline_svg}
    </figure>
</body>
</html>
"""

with open("plot.html", "w", encoding="utf-8") as html_file:
    html_file.write(html_content)