Skip to content

Commit 575aac3

Browse files
update(raincloud-basic): altair — fix orientation consistency (#4227)
## Summary Updated **altair** implementation for **raincloud-basic**. **Changes:** Fix cloud/rain orientation so cloud (half-violin/KDE) extends UPWARD and rain (jittered points) falls DOWNWARD from each category line. ### Changes - Fixed orientation: cloud now extends upward (positive y-direction), rain falls downward - Updated spec to use absolute directional terms instead of ambiguous "TOP"/"BELOW" - Preserved existing review strengths ## Test Plan - [x] Preview images uploaded to GCS staging - [x] Implementation file passes ruff format/check - [x] Metadata YAML updated with current versions - [ ] Automated review triggered Fixes #3745 --- Generated with [Claude Code](https://claude.com/claude-code) `/update` command --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 133fc94 commit 575aac3

2 files changed

Lines changed: 338 additions & 186 deletions

File tree

Lines changed: 190 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
""" pyplots.ai
22
raincloud-basic: Basic Raincloud Plot
3-
Library: altair 6.0.0 | Python 3.13.11
4-
Quality: 90/100 | Created: 2025-12-25
3+
Library: altair 6.0.0 | Python 3.14
4+
Quality: 94/100 | Created: 2025-12-25
55
"""
66

77
import altair as alt
@@ -12,15 +12,9 @@
1212
# Data: Reaction times (ms) for different treatment conditions
1313
np.random.seed(42)
1414

15-
# Create realistic reaction time data with different distributions
1615
control = np.random.normal(450, 60, 80)
17-
treatment_a = np.random.normal(380, 50, 80) # Faster responses
18-
treatment_b = np.concatenate(
19-
[
20-
np.random.normal(350, 30, 50), # Bimodal distribution
21-
np.random.normal(450, 40, 30),
22-
]
23-
)
16+
treatment_a = np.random.normal(380, 50, 80)
17+
treatment_b = np.concatenate([np.random.normal(340, 25, 50), np.random.normal(460, 35, 30)])
2418

2519
data = pd.DataFrame(
2620
{
@@ -29,43 +23,89 @@
2923
}
3024
)
3125

32-
# Map conditions to numeric positions - HORIZONTAL: y=categories, x=values
26+
# Map conditions to numeric y positions with spacing
3327
condition_order = ["Control", "Treatment A", "Treatment B"]
34-
condition_map = {c: i for i, c in enumerate(condition_order)}
28+
condition_map = {c: i * 1.5 for i, c in enumerate(condition_order)}
3529
data["condition_num"] = data["condition"].map(condition_map)
3630

37-
# Create jittered y positions for strip plot (rain BELOW the cloud)
38-
np.random.seed(42)
39-
data["jitter"] = np.random.uniform(-0.35, -0.15, len(data))
40-
data["jitter_pos"] = data["condition_num"] + data["jitter"]
31+
# Jitter positions for rain — BELOW baseline
32+
data["jitter_pos"] = data["condition_num"] + np.random.uniform(-0.30, -0.08, len(data))
33+
34+
# Box plot statistics per condition (streamlined)
35+
box_rows = []
36+
for cond in condition_order:
37+
vals = data.loc[data["condition"] == cond, "reaction_time"]
38+
q1, med, q3 = vals.quantile([0.25, 0.5, 0.75])
39+
iqr = q3 - q1
40+
box_rows.append(
41+
{
42+
"condition": cond,
43+
"condition_num": condition_map[cond],
44+
"q1": q1,
45+
"median": med,
46+
"q3": q3,
47+
"lower_w": max(q1 - 1.5 * iqr, vals.min()),
48+
"upper_w": min(q3 + 1.5 * iqr, vals.max()),
49+
}
50+
)
51+
box_df = pd.DataFrame(box_rows)
52+
53+
# Color palette: Python Blue, dark gold (high contrast), fresh green
54+
colors = ["#306998", "#D4A017", "#4CAF50"]
4155

42-
# Half-violin (cloud) - positioned on TOP (positive y offset from center)
56+
# Tighten x domain to actual data range
57+
x_min = data["reaction_time"].min()
58+
x_max = data["reaction_time"].max()
59+
x_pad = (x_max - x_min) * 0.06
60+
x_scale = alt.Scale(domain=[round(x_min - x_pad, -1), round(x_max + x_pad, -1)])
61+
y_domain = [-0.5, 3.9]
62+
x_axis = alt.Axis(
63+
titleFontSize=22,
64+
titleFontWeight="bold",
65+
titleColor="#333333",
66+
labelFontSize=18,
67+
labelColor="#555555",
68+
grid=True,
69+
gridOpacity=0.25,
70+
gridColor="#d0d0d0",
71+
gridDash=[3, 3],
72+
domainColor="#999999",
73+
tickColor="#999999",
74+
tickCount=8,
75+
)
76+
77+
# Half-violin cloud — extends ABOVE baseline
4378
violin = (
4479
alt.Chart(data)
4580
.transform_density(
46-
"reaction_time", as_=["reaction_time", "density"], groupby=["condition", "condition_num"], extent=[200, 600]
47-
)
48-
.transform_calculate(
49-
# Scale density and offset to create half-violin on TOP (positive y direction)
50-
violin_pos="datum.condition_num + 0.05 + datum.density * 180"
81+
"reaction_time",
82+
as_=["reaction_time", "density"],
83+
groupby=["condition", "condition_num"],
84+
extent=[round(x_min - x_pad, -1), round(x_max + x_pad, -1)],
5185
)
52-
.mark_area(orient="vertical", opacity=0.7)
86+
.transform_calculate(violin_pos="datum.condition_num + 0.04 + datum.density * 105")
87+
.mark_area(orient="vertical", opacity=0.55, interpolate="monotone")
5388
.encode(
54-
x=alt.X("reaction_time:Q"),
55-
y=alt.Y("condition_num:Q", axis=None).scale(domain=[-0.6, 2.6]),
89+
x=alt.X("reaction_time:Q", title="Reaction Time (ms)", scale=x_scale, axis=x_axis),
90+
y=alt.Y("condition_num:Q", axis=None, scale=alt.Scale(domain=y_domain)),
5691
y2="violin_pos:Q",
5792
color=alt.Color(
5893
"condition:N",
59-
scale=alt.Scale(domain=condition_order, range=["#306998", "#FFD43B", "#4CAF50"]),
94+
scale=alt.Scale(domain=condition_order, range=colors),
6095
legend=alt.Legend(
6196
title="Condition",
6297
titleFontSize=20,
98+
titleFontWeight="bold",
6399
labelFontSize=18,
64-
orient="right",
100+
orient="none",
101+
legendX=1350,
102+
legendY=50,
65103
fillColor="white",
66104
strokeColor="#cccccc",
67-
padding=12,
68-
cornerRadius=4,
105+
padding=14,
106+
cornerRadius=6,
107+
symbolSize=200,
108+
direction="vertical",
69109
),
70110
),
71111
tooltip=[
@@ -75,66 +115,151 @@
75115
)
76116
)
77117

78-
# Box plot - HORIZONTAL orientation (values on x, categories on y)
79-
boxplot = (
80-
alt.Chart(data)
81-
.transform_calculate(box_pos="datum.condition_num + 0.02")
82-
.mark_boxplot(
83-
size=30,
84-
orient="horizontal",
85-
median={"color": "white", "strokeWidth": 3},
86-
box={"strokeWidth": 2},
87-
outliers={"opacity": 0}, # Hide outliers, shown as jittered points
118+
# Box plot — WHITE fill with dark outline to distinguish from cloud
119+
box_iqr = (
120+
alt.Chart(box_df)
121+
.mark_bar(height=18, stroke="#333333", strokeWidth=2, cornerRadius=2, fill="white", fillOpacity=0.85)
122+
.encode(
123+
x=alt.X("q1:Q", scale=x_scale),
124+
x2="q3:Q",
125+
y=alt.Y("condition_num:Q", axis=None, scale=alt.Scale(domain=y_domain)),
88126
)
127+
)
128+
129+
# Median line — contrasting red
130+
box_median = (
131+
alt.Chart(box_df)
132+
.mark_tick(thickness=3.5, color="#E8413C", orient="vertical")
133+
.encode(x=alt.X("median:Q", scale=x_scale), y=alt.Y("condition_num:Q", axis=None, scale=alt.Scale(domain=y_domain)))
134+
)
135+
136+
# Whisker lines
137+
box_whiskers = (
138+
alt.Chart(box_df)
139+
.mark_rule(strokeWidth=1.5, color="#555555")
89140
.encode(
90-
x=alt.X("reaction_time:Q", title="Reaction Time (ms)", scale=alt.Scale(domain=[200, 600])),
91-
y=alt.Y("box_pos:Q", axis=None),
92-
color=alt.Color(
93-
"condition:N", scale=alt.Scale(domain=condition_order, range=["#306998", "#FFD43B", "#4CAF50"])
94-
),
141+
x=alt.X("lower_w:Q", scale=x_scale),
142+
x2="upper_w:Q",
143+
y=alt.Y("condition_num:Q", axis=None, scale=alt.Scale(domain=y_domain)),
95144
)
96145
)
97146

98-
# Jittered strip plot (rain) - positioned clearly BELOW the center
147+
# Whisker caps
148+
whisker_cap_data = pd.concat(
149+
[
150+
box_df[["condition_num", "lower_w"]].rename(columns={"lower_w": "x"}),
151+
box_df[["condition_num", "upper_w"]].rename(columns={"upper_w": "x"}),
152+
]
153+
)
154+
whisker_cap_data["y1"] = whisker_cap_data["condition_num"] - 0.06
155+
whisker_cap_data["y2"] = whisker_cap_data["condition_num"] + 0.06
156+
whisker_caps = (
157+
alt.Chart(whisker_cap_data)
158+
.mark_rule(strokeWidth=1.5, color="#555555")
159+
.encode(x=alt.X("x:Q", scale=x_scale), y=alt.Y("y1:Q", axis=None, scale=alt.Scale(domain=y_domain)), y2="y2:Q")
160+
)
161+
162+
# Jittered strip — rain BELOW baseline
99163
strip = (
100164
alt.Chart(data)
101-
.mark_circle(size=40, opacity=0.6)
165+
.mark_circle(size=45, opacity=0.5, stroke="#444444", strokeWidth=0.4)
102166
.encode(
103-
x=alt.X("reaction_time:Q"),
167+
x=alt.X("reaction_time:Q", scale=x_scale),
104168
y=alt.Y("jitter_pos:Q", axis=None),
105-
color=alt.Color(
106-
"condition:N", scale=alt.Scale(domain=condition_order, range=["#306998", "#FFD43B", "#4CAF50"])
107-
),
169+
color=alt.Color("condition:N", scale=alt.Scale(domain=condition_order, range=colors)),
108170
tooltip=[
109171
alt.Tooltip("condition:N", title="Condition"),
110172
alt.Tooltip("reaction_time:Q", title="Reaction Time (ms)", format=".1f"),
111173
],
112174
)
113175
)
114176

115-
# Main chart layer with raincloud elements
116-
main_chart = (
117-
alt.layer(violin, boxplot, strip).properties(width=1600, height=850).interactive() # Enable zoom and pan
177+
# Annotation: highlight Treatment B bimodality
178+
annotation_data = pd.DataFrame(
179+
[{"x": 340, "y": 3.0 + 0.65, "text": "Peak 1"}, {"x": 460, "y": 3.0 + 0.50, "text": "Peak 2"}]
180+
)
181+
bimodal_labels = (
182+
alt.Chart(annotation_data)
183+
.mark_text(fontSize=16, fontStyle="italic", color="#444444", fontWeight="bold")
184+
.encode(x=alt.X("x:Q"), y=alt.Y("y:Q", axis=None), text="text:N")
185+
)
186+
187+
arrow_data = pd.DataFrame([{"x": 355, "y": 3.0 + 0.58, "x2": 445, "y2": 3.0 + 0.53}])
188+
bimodal_arrow = (
189+
alt.Chart(arrow_data)
190+
.mark_rule(strokeDash=[4, 3], color="#777777", strokeWidth=1.5)
191+
.encode(x=alt.X("x:Q"), y=alt.Y("y:Q", axis=None), x2="x2:Q", y2="y2:Q")
118192
)
119193

120-
# Y-axis labels as a separate chart on the left
121-
y_axis_data = pd.DataFrame({"condition": condition_order, "y_pos": [0, 1, 2]})
194+
note_data = pd.DataFrame([{"x": 400, "y": 3.0 + 0.78, "text": "Bimodal distribution"}])
195+
bimodal_note = (
196+
alt.Chart(note_data)
197+
.mark_text(fontSize=15, color="#555555", fontStyle="italic")
198+
.encode(x=alt.X("x:Q"), y=alt.Y("y:Q", axis=None), text="text:N")
199+
)
200+
201+
# Median value annotations on box plots
202+
median_labels = (
203+
alt.Chart(box_df)
204+
.transform_calculate(label_y="datum.condition_num + 0.18")
205+
.mark_text(fontSize=14, color="#E8413C", fontWeight="bold", dy=-14)
206+
.encode(
207+
x=alt.X("median:Q", scale=x_scale),
208+
y=alt.Y("condition_num:Q", axis=None, scale=alt.Scale(domain=y_domain)),
209+
text=alt.Text("median:Q", format=".0f"),
210+
)
211+
)
122212

123-
y_axis_labels = (
124-
alt.Chart(y_axis_data)
125-
.mark_text(fontSize=20, fontWeight="bold", align="right", baseline="middle")
126-
.encode(y=alt.Y("y_pos:Q", scale=alt.Scale(domain=[-0.6, 2.6]), axis=None), text="condition:N")
127-
.properties(width=120, height=850)
213+
# Y-axis tick labels
214+
y_label_data = pd.DataFrame({"condition": condition_order, "y_pos": [0, 1.5, 3.0]})
215+
y_labels = (
216+
alt.Chart(y_label_data)
217+
.transform_calculate(x=str(round(x_min - x_pad, -1)))
218+
.mark_text(fontSize=20, fontWeight="bold", align="right", baseline="middle", dx=-25, clip=False, color="#333333")
219+
.encode(x=alt.X("x:Q"), y=alt.Y("y_pos:Q", axis=None, scale=alt.Scale(domain=y_domain)), text="condition:N")
128220
)
129221

130-
# Combine using horizontal concatenation
222+
# Compose all layers
131223
chart = (
132-
alt.hconcat(y_axis_labels, main_chart, spacing=5)
133-
.properties(title=alt.Title("raincloud-basic · altair · pyplots.ai", fontSize=28, anchor="middle"))
134-
.configure_axis(labelFontSize=18, titleFontSize=22, gridOpacity=0.3)
224+
alt.layer(
225+
violin,
226+
box_whiskers,
227+
whisker_caps,
228+
box_iqr,
229+
box_median,
230+
median_labels,
231+
strip,
232+
bimodal_labels,
233+
bimodal_arrow,
234+
bimodal_note,
235+
y_labels,
236+
)
237+
.resolve_scale(y="shared", x="shared")
238+
.properties(
239+
width=1600,
240+
height=900,
241+
title=alt.Title(
242+
"raincloud-basic · altair · pyplots.ai",
243+
fontSize=28,
244+
anchor="middle",
245+
offset=20,
246+
color="#333333",
247+
fontWeight="bold",
248+
),
249+
)
250+
.configure(padding={"left": 140, "right": 20, "top": 10, "bottom": 40})
251+
.configure_axis(
252+
labelFontSize=18,
253+
titleFontSize=22,
254+
gridColor="#d0d0d0",
255+
gridOpacity=0.25,
256+
domainColor="#999999",
257+
tickColor="#999999",
258+
)
135259
.configure_view(strokeWidth=0)
260+
.interactive()
136261
)
137262

138-
# Save outputs
263+
# Save
139264
chart.save("plot.png", scale_factor=3.0)
140265
chart.save("plot.html")

0 commit comments

Comments
 (0)