Skip to content

Commit 78545e2

Browse files
Merge branch 'main' into implementation/heatmap-cohort-retention/altair
2 parents 9008f4d + d30b48c commit 78545e2

18 files changed

Lines changed: 3407 additions & 0 deletions

File tree

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
""" pyplots.ai
2+
heatmap-cohort-retention: Cohort Retention Heatmap
3+
Library: bokeh 3.9.0 | Python 3.14.3
4+
Quality: 90/100 | Created: 2026-03-16
5+
"""
6+
7+
import numpy as np
8+
from bokeh.io import export_png
9+
from bokeh.models import BasicTicker, ColorBar, ColumnDataSource, HoverTool, Label, LinearColorMapper
10+
from bokeh.plotting import figure
11+
from bokeh.transform import transform
12+
13+
14+
# Data: Monthly signup cohorts with retention tracking
15+
np.random.seed(42)
16+
cohort_labels = [
17+
"Jan 2024",
18+
"Feb 2024",
19+
"Mar 2024",
20+
"Apr 2024",
21+
"May 2024",
22+
"Jun 2024",
23+
"Jul 2024",
24+
"Aug 2024",
25+
"Sep 2024",
26+
"Oct 2024",
27+
]
28+
n_cohorts = len(cohort_labels)
29+
n_periods = 10
30+
cohort_sizes = np.random.randint(800, 2500, size=n_cohorts)
31+
32+
# Generate realistic retention data (triangular shape)
33+
retention = np.full((n_cohorts, n_periods), np.nan)
34+
for i in range(n_cohorts):
35+
max_periods = n_periods - i
36+
retention[i, 0] = 100.0
37+
base_decay = np.random.uniform(0.65, 0.80)
38+
for j in range(1, max_periods):
39+
decay = base_decay + np.random.uniform(-0.05, 0.05)
40+
retention[i, j] = retention[i, j - 1] * decay
41+
retention[i, j] = max(retention[i, j], 2.0)
42+
43+
# Prepare data for bokeh heatmap
44+
x_coords = []
45+
y_coords = []
46+
values = []
47+
text_values = []
48+
text_colors = []
49+
50+
period_labels = [f"Month {i}" for i in range(n_periods)]
51+
y_labels = [f"{label} (n={size:,})" for label, size in zip(cohort_labels, cohort_sizes, strict=True)]
52+
53+
for i in range(n_cohorts):
54+
for j in range(n_periods):
55+
if not np.isnan(retention[i, j]):
56+
x_coords.append(period_labels[j])
57+
y_coords.append(y_labels[i])
58+
val = retention[i, j]
59+
values.append(val)
60+
text_values.append(f"{val:.1f}%")
61+
# Adaptive text color with refined threshold for viridis
62+
text_colors.append("white" if val < 45 else "#1a1a1a")
63+
64+
source = ColumnDataSource(
65+
data={"x": x_coords, "y": y_coords, "value": values, "text": text_values, "text_color": text_colors}
66+
)
67+
68+
# Viridis-inspired palette: perceptually uniform, colorblind-safe
69+
viridis_palette = [
70+
"#440154",
71+
"#482878",
72+
"#3e4989",
73+
"#31688e",
74+
"#26828e",
75+
"#1f9e89",
76+
"#35b779",
77+
"#6ece58",
78+
"#b5de2b",
79+
"#fde725",
80+
]
81+
mapper = LinearColorMapper(palette=viridis_palette, low=0, high=100)
82+
83+
# Create figure
84+
p = figure(
85+
width=4800,
86+
height=2700,
87+
x_range=period_labels,
88+
y_range=list(reversed(y_labels)),
89+
title="heatmap-cohort-retention · bokeh · pyplots.ai",
90+
x_axis_location="above",
91+
toolbar_location=None,
92+
)
93+
94+
# Add heatmap rectangles
95+
rects = p.rect(
96+
x="x",
97+
y="y",
98+
width=1,
99+
height=1,
100+
source=source,
101+
fill_color=transform("value", mapper),
102+
line_color="#f0f0f0",
103+
line_width=2,
104+
)
105+
106+
# Add HoverTool for interactive exploration (distinctive Bokeh feature)
107+
hover = HoverTool(renderers=[rects], tooltips=[("Cohort", "@y"), ("Period", "@x"), ("Retention", "@text")])
108+
p.add_tools(hover)
109+
110+
# Add retention percentage text
111+
p.text(
112+
x="x",
113+
y="y",
114+
text="text",
115+
source=source,
116+
text_align="center",
117+
text_baseline="middle",
118+
text_font_size="17pt",
119+
text_color="text_color",
120+
text_font_style="bold",
121+
)
122+
123+
# Style: refined typography and spacing
124+
p.title.text_font_size = "30pt"
125+
p.title.align = "center"
126+
p.title.text_color = "#2d2d2d"
127+
p.xaxis.axis_label = "Months Since Signup"
128+
p.yaxis.axis_label = "Signup Cohort"
129+
p.xaxis.axis_label_text_font_size = "22pt"
130+
p.yaxis.axis_label_text_font_size = "22pt"
131+
p.xaxis.axis_label_text_font_style = "bold"
132+
p.yaxis.axis_label_text_font_style = "bold"
133+
p.xaxis.axis_label_text_color = "#3a3a3a"
134+
p.yaxis.axis_label_text_color = "#3a3a3a"
135+
p.xaxis.major_label_text_font_size = "18pt"
136+
p.yaxis.major_label_text_font_size = "18pt"
137+
p.xaxis.major_label_text_color = "#4a4a4a"
138+
p.yaxis.major_label_text_color = "#4a4a4a"
139+
p.axis.axis_line_color = None
140+
p.axis.major_tick_line_color = None
141+
p.axis.minor_tick_line_color = None
142+
p.grid.grid_line_color = None
143+
p.background_fill_color = "#fafafa"
144+
p.border_fill_color = "white"
145+
p.outline_line_color = None
146+
p.min_border_left = 80
147+
p.min_border_right = 120
148+
p.min_border_top = 60
149+
p.min_border_bottom = 40
150+
151+
# Storytelling: annotate the retention drop-off insight
152+
# Find the best and worst performing cohorts at Month 3
153+
month3_retentions = {y_labels[i]: retention[i, 3] for i in range(n_cohorts) if not np.isnan(retention[i, 3])}
154+
best_cohort = max(month3_retentions, key=month3_retentions.get)
155+
worst_cohort = min(month3_retentions, key=month3_retentions.get)
156+
157+
insight_label = Label(
158+
x=30,
159+
y=30,
160+
x_units="screen",
161+
y_units="screen",
162+
text=(
163+
f"Month 3 retention ranges from {month3_retentions[worst_cohort]:.0f}% "
164+
f"to {month3_retentions[best_cohort]:.0f}% across cohorts"
165+
),
166+
text_font_size="16pt",
167+
text_color="#666666",
168+
text_font_style="italic",
169+
)
170+
p.add_layout(insight_label)
171+
172+
# Add colorbar with improved spacing
173+
color_bar = ColorBar(
174+
color_mapper=mapper,
175+
ticker=BasicTicker(desired_num_ticks=6),
176+
label_standoff=16,
177+
major_label_text_font_size="18pt",
178+
major_label_text_color="#4a4a4a",
179+
title="Retention %",
180+
title_text_font_size="20pt",
181+
title_text_font_style="bold",
182+
title_standoff=16,
183+
width=45,
184+
location=(0, 0),
185+
bar_line_color=None,
186+
border_line_color=None,
187+
background_fill_color="white",
188+
)
189+
p.add_layout(color_bar, "right")
190+
191+
# Save
192+
export_png(p, filename="plot.png")

0 commit comments

Comments
 (0)