Skip to content

Commit 2ea5f93

Browse files
feat(pygal): implement heatmap-cohort-retention
1 parent 0dd291f commit 2ea5f93

1 file changed

Lines changed: 342 additions & 0 deletions

File tree

  • plots/heatmap-cohort-retention/implementations
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
"""pyplots.ai
2+
heatmap-cohort-retention: Cohort Retention Heatmap
3+
Library: pygal | Python 3.13
4+
Quality: pending | Created: 2026-03-16
5+
"""
6+
7+
import sys
8+
9+
import numpy as np
10+
11+
12+
# Import pygal avoiding name collision with this filename
13+
_cwd = sys.path[0] if sys.path[0] else "."
14+
if _cwd in sys.path:
15+
sys.path.remove(_cwd)
16+
17+
from pygal.graph.graph import Graph # noqa: E402
18+
from pygal.style import Style # noqa: E402
19+
20+
21+
sys.path.insert(0, _cwd)
22+
23+
24+
class CohortRetentionHeatmap(Graph):
    """Cohort retention heatmap drawn directly into a pygal plot node.

    The heatmap inputs are passed as keyword arguments; they are removed from
    ``kwargs`` before the remaining options are forwarded to the base class:

    * ``matrix_data`` -- list of rows of retention percentages. Rows may have
      different lengths and individual entries may be ``None`` (not drawn).
    * ``row_labels`` -- one label per row, drawn to the left of the grid.
    * ``col_labels`` -- one label per column, drawn above the grid.
    * ``cohort_sizes`` -- per-row sizes rendered as ``n=<size>`` under the
      row label.
    * ``colormap`` -- list of ``#rrggbb`` colour stops, lowest value first.
    """

    _series_margin = 0

    # Used when no (or an empty) ``colormap`` is supplied, so that the default
    # ``colormap=[]`` no longer ends in an IndexError while colouring cells.
    _fallback_colormap = (
        "#f7fcf5",
        "#e5f5e0",
        "#c7e9c0",
        "#a1d99b",
        "#74c476",
        "#41ab5d",
        "#238b45",
        "#005a32",
    )

    def __init__(self, *args, **kwargs):
        """Pop the heatmap-specific keyword arguments, then defer to the base class."""
        self.matrix_data = kwargs.pop("matrix_data", [])
        self.row_labels = kwargs.pop("row_labels", [])
        self.col_labels = kwargs.pop("col_labels", [])
        self.cohort_sizes = kwargs.pop("cohort_sizes", [])
        self.colormap = kwargs.pop("colormap", [])
        super().__init__(*args, **kwargs)

    def _interpolate_color(self, value, min_val, max_val):
        """Return the ``#rrggbb`` colour for ``value`` on the ``[min_val, max_val]`` scale.

        The value is normalised to ``[0, 1]`` (clamped), mapped onto the
        colour stops and linearly interpolated per RGB channel between the two
        neighbouring stops. When ``min_val == max_val`` the middle stop is
        returned.
        """
        palette = self.colormap or self._fallback_colormap
        last = len(palette) - 1
        if max_val == min_val:
            return palette[len(palette) // 2]

        normalized = max(0, min(1, (value - min_val) / (max_val - min_val)))
        pos = normalized * last
        idx1 = int(pos)
        idx2 = min(idx1 + 1, last)
        frac = pos - idx1
        c1, c2 = palette[idx1], palette[idx2]

        channels = []
        for start in (1, 3, 5):
            low = int(c1[start:start + 2], 16)
            high = int(c2[start:start + 2], 16)
            # int() truncates toward zero, matching the per-channel arithmetic
            # this method has always used.
            channels.append(int(low + (high - low) * frac))
        r, g, b = channels
        return f"#{r:02x}{g:02x}{b:02x}"

    def _get_text_color(self, bg_color):
        """Return white text for dark ``#rrggbb`` backgrounds, dark grey otherwise."""
        r = int(bg_color[1:3], 16)
        g = int(bg_color[3:5], 16)
        b = int(bg_color[5:7], 16)
        # Weighted brightness of the background; 140 is the light/dark cut-off.
        brightness = (r * 299 + g * 587 + b * 114) / 1000
        return "#ffffff" if brightness < 140 else "#222222"

    def _add_text_label(self, parent, x, y, content, fill, style, anchor=None, transform=None):
        """Append a ``<text>`` node to ``parent`` and return it.

        Attributes are set in a fixed order (anchor, fill, style, transform);
        ``anchor`` and ``transform`` are only written when given.
        """
        node = self.svg.node(parent, "text", x=x, y=y)
        if anchor is not None:
            node.set("text-anchor", anchor)
        node.set("fill", fill)
        node.set("style", style)
        if transform is not None:
            node.set("transform", transform)
        node.text = content
        return node

    def _plot(self):
        """Draw axis titles, headers, row labels, cells and the colorbar.

        Nothing is drawn when ``matrix_data`` is empty or contains no non-null
        value.
        """
        if not self.matrix_data:
            return

        # The colour scale spans every non-null value, including the period-0
        # column.
        non_null = [v for row in self.matrix_data for v in row if v is not None]
        if not non_null:
            # Without a single value there is no colour range (min()/max()
            # would raise) and possibly no columns to size the cells with.
            return
        min_val = min(non_null)
        max_val = max(non_null)

        n_rows = len(self.matrix_data)
        n_cols = max(len(row) for row in self.matrix_data)

        plot_width = self.view.width
        plot_height = self.view.height

        # Space reserved inside the plot area for labels and the colorbar.
        label_margin_left = 460
        label_margin_right = 280
        label_margin_top = 130
        label_margin_bottom = 10

        available_width = plot_width - label_margin_left - label_margin_right
        available_height = plot_height - label_margin_top - label_margin_bottom

        cell_width = available_width / n_cols
        cell_height = available_height / (n_rows + 0.5)
        gap = 3

        grid_width = n_cols * (cell_width + gap) - gap
        grid_height = n_rows * (cell_height + gap) - gap

        x_offset = self.view.x(0) + label_margin_left + (available_width - grid_width) / 2
        y_offset = self.view.y(n_rows) + label_margin_top + (available_height - grid_height - cell_height * 0.5) / 2

        plot_node = self.nodes["plot"]

        # Column headers title
        col_font_size = min(38, int(cell_width * 0.48))
        axis_title_style = f"font-size:{col_font_size + 4}px;font-weight:600;font-family:sans-serif"
        self._add_text_label(
            plot_node,
            x_offset + grid_width / 2,
            y_offset - 80,
            "Months Since Signup",
            "#555555",
            axis_title_style,
            anchor="middle",
        )

        # Column headers
        col_header_style = f"font-size:{col_font_size}px;font-weight:700;font-family:sans-serif"
        for j, label in enumerate(self.col_labels):
            cx = x_offset + j * (cell_width + gap) + cell_width / 2
            cy = y_offset - 18
            self._add_text_label(plot_node, cx, cy, str(label), "#333333", col_header_style, anchor="middle")

        # Row labels with cohort sizes
        row_font_size = min(36, int(cell_height * 0.50))
        size_font_size = int(row_font_size * 0.78)
        row_label_style = f"font-size:{row_font_size}px;font-weight:600;font-family:sans-serif"
        size_label_style = f"font-size:{size_font_size}px;font-weight:400;font-family:sans-serif"
        for i, label in enumerate(self.row_labels):
            ry = y_offset + i * (cell_height + gap) + cell_height / 2
            rx = x_offset - 20

            # Cohort label
            self._add_text_label(
                plot_node, rx, ry + row_font_size * 0.15, str(label), "#333333", row_label_style, anchor="end"
            )

            # Cohort size below label
            if i < len(self.cohort_sizes):
                self._add_text_label(
                    plot_node,
                    rx,
                    ry + row_font_size * 0.15 + size_font_size + 4,
                    f"n={self.cohort_sizes[i]:,}",
                    "#888888",
                    size_label_style,
                    anchor="end",
                )

        # Row label title (rotated)
        row_title_x = x_offset - 340
        row_title_y = y_offset + grid_height / 2
        self._add_text_label(
            plot_node,
            row_title_x,
            row_title_y,
            "Signup Cohort",
            "#555555",
            axis_title_style,
            anchor="middle",
            transform=f"rotate(-90, {row_title_x}, {row_title_y})",
        )

        # Draw cells
        value_font_size = min(38, int(min(cell_width, cell_height) * 0.46))
        value_style = f"font-size:{value_font_size}px;font-weight:500;font-family:sans-serif"
        for i, row in enumerate(self.matrix_data):
            for j, value in enumerate(row):
                if value is None:
                    continue

                color = self._interpolate_color(value, min_val, max_val)
                text_color = self._get_text_color(color)

                cx = x_offset + j * (cell_width + gap)
                cy = y_offset + i * (cell_height + gap)

                cell_group = self.svg.node(plot_node, "g", class_="cell")
                rect = self.svg.node(cell_group, "rect", x=cx, y=cy, width=cell_width, height=cell_height, rx=3, ry=3)
                rect.set("fill", color)
                rect.set("stroke", "#ffffff")
                rect.set("stroke-width", "2")

                # Tooltip (added before the value text so the child order of
                # the cell group stays rect, tooltip data, text)
                cohort_label = self.row_labels[i] if i < len(self.row_labels) else ""
                period_label = self.col_labels[j] if j < len(self.col_labels) else ""
                self._tooltip_data(
                    cell_group,
                    f"{value:.1f}%",
                    cx + cell_width / 2,
                    cy + cell_height / 2,
                    xlabel=f"{cohort_label} / {period_label}",
                )

                # Value text
                self._add_text_label(
                    cell_group,
                    cx + cell_width / 2,
                    cy + cell_height / 2 + value_font_size * 0.35,
                    f"{value:.0f}%",
                    text_color,
                    value_style,
                    anchor="middle",
                )

        # Colorbar
        cb_width = 46
        cb_height = grid_height * 0.75
        cb_x = x_offset + grid_width + 50
        cb_y = y_offset + (grid_height - cb_height) / 2

        defs = self.svg.node(plot_node, "defs")
        gradient = self.svg.node(defs, "linearGradient", id="cb-gradient", x1="0", y1="0", x2="0", y2="1")
        # 21 evenly spaced stops, running from max_val at the top to min_val
        # at the bottom.
        for frac_i in range(21):
            frac = frac_i / 20.0
            val = max_val - (max_val - min_val) * frac
            color = self._interpolate_color(val, min_val, max_val)
            stop = self.svg.node(gradient, "stop", offset=f"{frac * 100}%")
            stop.set("stop-color", color)

        cb_rect = self.svg.node(plot_node, "rect", x=cb_x, y=cb_y, width=cb_width, height=cb_height, rx=4, ry=4)
        cb_rect.set("fill", "url(#cb-gradient)")
        cb_rect.set("stroke", "#999999")
        cb_rect.set("stroke-width", "1.5")

        cb_label_size = 28
        cb_tick_style = f"font-size:{cb_label_size}px;font-family:sans-serif"
        for frac, val in [
            (0.0, max_val),
            (0.25, max_val * 0.75 + min_val * 0.25),
            (0.5, (min_val + max_val) / 2),
            (0.75, max_val * 0.25 + min_val * 0.75),
            (1.0, min_val),
        ]:
            ty = cb_y + cb_height * frac
            tick = self.svg.node(plot_node, "line", x1=cb_x + cb_width, y1=ty, x2=cb_x + cb_width + 8, y2=ty)
            tick.set("stroke", "#666666")
            tick.set("stroke-width", "1.5")
            self._add_text_label(
                plot_node,
                cb_x + cb_width + 14,
                ty + cb_label_size * 0.35,
                f"{val:.0f}%",
                "#333333",
                cb_tick_style,
            )

        self._add_text_label(
            plot_node,
            cb_x + cb_width / 2,
            cb_y - 20,
            "Retention %",
            "#333333",
            f"font-size:{cb_label_size + 2}px;font-weight:600;font-family:sans-serif",
            anchor="middle",
        )

    def _compute(self):
        """Set the data box to ``n_cols`` x ``n_rows`` (1 x 1 without data)."""
        n_rows = len(self.matrix_data) if self.matrix_data else 1
        n_cols = max(len(row) for row in self.matrix_data) if self.matrix_data else 1
        self._box.xmin = 0
        self._box.xmax = n_cols
        self._box.ymin = 0
        self._box.ymax = n_rows
233+
234+
235+
# Data: synthetic monthly signup cohorts and their retention percentages.
# The global NumPy seed keeps the generated figures reproducible.
np.random.seed(42)

cohort_labels = [
    "Jan 2024",
    "Feb 2024",
    "Mar 2024",
    "Apr 2024",
    "May 2024",
    "Jun 2024",
    "Jul 2024",
    "Aug 2024",
    "Sep 2024",
    "Oct 2024",
]
n_cohorts = len(cohort_labels)
n_max_periods = 10

cohort_sizes = [1200, 1350, 980, 1520, 1100, 1430, 1280, 1050, 1380, 1150]

# Baseline retention per period; it decays as the period index grows.
base_retention = np.array([100.0, 68.0, 52.0, 43.0, 37.0, 33.0, 30.0, 28.0, 26.5, 25.0])

# Triangular matrix: each later cohort has one fewer observed period.
matrix = []
for cohort_idx in range(n_cohorts):
    observed_periods = n_max_periods - cohort_idx
    retention_row = []
    for period in range(observed_periods):
        if period > 0:
            # Later cohorts get a small uplift; one random draw per cell, in
            # row-major order, adds the noise.
            uplift = cohort_idx * 0.8
            jitter = np.random.uniform(-2.5, 2.5)
            raw = base_retention[period] + uplift + jitter
            retention_row.append(round(max(5.0, min(100.0, raw)), 1))
        else:
            # Period 0 is pinned to 100.0 and consumes no random draw.
            retention_row.append(100.0)
    matrix.append(retention_row)

period_labels = [f"Month {k}" for k in range(n_max_periods)]

# Sequential green colour stops, ordered light to dark.
green_colormap = ["#f7fcf5", "#e5f5e0", "#c7e9c0", "#a1d99b", "#74c476", "#41ab5d", "#238b45", "#005a32"]
279+
280+
# Style: white canvas, dark grey text, large fonts for the 4800x2700 output.
_style_options = {
    "background": "white",
    "plot_background": "white",
    "foreground": "#333333",
    "foreground_strong": "#333333",
    "foreground_subtle": "#666666",
    "colors": ("#238b45",),
    "title_font_size": 54,
    "legend_font_size": 28,
    "label_font_size": 34,
    "value_font_size": 28,
    "tooltip_font_size": 26,
    "font_family": "sans-serif",
}
custom_style = Style(**_style_options)

# Chart: heatmap inputs and the generic chart options are kept apart for
# readability, then passed together as keyword arguments.
_heatmap_inputs = {
    "matrix_data": matrix,
    "row_labels": cohort_labels,
    "col_labels": period_labels,
    "cohort_sizes": cohort_sizes,
    "colormap": green_colormap,
}
_chart_options = {
    "width": 4800,
    "height": 2700,
    "style": custom_style,
    "title": "heatmap-cohort-retention \u00b7 pygal \u00b7 pyplots.ai",
    "show_legend": False,
    "margin": 100,
    "margin_top": 200,
    "margin_bottom": 30,
    "margin_left": 120,
    "margin_right": 120,
    "show_x_labels": False,
    "show_y_labels": False,
}
chart = CohortRetentionHeatmap(**_chart_options, **_heatmap_inputs)

# A single unnamed placeholder series is registered with the chart; the
# heatmap itself is drawn from matrix_data.
chart.add("", [0])

# Save the SVG and PNG renderings.
chart.render_to_file("plot.svg")
chart.render_to_png("plot.png")
322+
323+
# Render the SVG for inlining. The XML prolog is disabled because a
# "<?xml ...?>" declaration is not valid inside an HTML body; with the prolog
# disabled pygal returns the markup as text.
svg_markup = chart.render(is_unicode=True, disable_xml_declaration=True)

# Stand-alone page that centres the inline SVG on a light grey background.
html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>heatmap-cohort-retention - pygal</title>
<style>
body {{ margin: 0; display: flex; justify-content: center; align-items: center; min-height: 100vh; background: #f5f5f5; }}
.chart {{ max-width: 100%; height: auto; }}
</style>
</head>
<body>
<figure class="chart">
{svg_markup}
</figure>
</body>
</html>
"""

with open("plot.html", "w", encoding="utf-8") as f:
    f.write(html_content)

0 commit comments

Comments
 (0)