Skip to content

Commit 402d7dc

Browse files
MarkusNeusingerclaude[bot]github-actions[bot]
authored
feat: add scatter-basic implementation (9 libraries) (#510)
## Summary Adds `scatter-basic` plot implementation. ### Libraries - **Merged:** 9 (all libraries) - **Not Feasible:** 0 ### Links - **Spec:** `specs/scatter-basic.md` - **Parent Issue:** #207 --- :robot: *Auto-generated by pyplots CI* --------- Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
1 parent 7066653 commit 402d7dc

File tree

9 files changed

+351
-231
lines changed

9 files changed

+351
-231
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
scatter-basic: Basic Scatter Plot
3+
Library: altair
4+
"""
5+
6+
import altair as alt
7+
import numpy as np
8+
import pandas as pd
9+
10+
11+
# Data
12+
np.random.seed(42)
13+
x = np.random.randn(100) * 2 + 10
14+
y = x * 0.8 + np.random.randn(100) * 2
15+
16+
data = pd.DataFrame({"x": x, "y": y})
17+
18+
# Create chart
19+
chart = (
20+
alt.Chart(data)
21+
.mark_point(filled=True, opacity=0.7, size=100, color="#306998")
22+
.encode(x=alt.X("x:Q", title="X Value"), y=alt.Y("y:Q", title="Y Value"), tooltip=["x:Q", "y:Q"])
23+
.properties(width=1600, height=900, title="Basic Scatter Plot")
24+
.configure_axis(labelFontSize=16, titleFontSize=20)
25+
.configure_title(fontSize=20)
26+
)
27+
28+
# Save as PNG (1600 * 3 = 4800px, 900 * 3 = 2700px)
29+
chart.save("plot.png", scale_factor=3.0)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
scatter-basic: Basic Scatter Plot
3+
Library: bokeh
4+
"""
5+
6+
import numpy as np
7+
from bokeh.io import export_png
8+
from bokeh.models import ColumnDataSource
9+
from bokeh.plotting import figure
10+
11+
12+
# Data
13+
np.random.seed(42)
14+
x = np.random.randn(100) * 2 + 10
15+
y = x * 0.8 + np.random.randn(100) * 2
16+
17+
source = ColumnDataSource(data={"x": x, "y": y})
18+
19+
# Create figure (4800 × 2700 px for 16:9 aspect ratio)
20+
p = figure(width=4800, height=2700, title="Basic Scatter Plot", x_axis_label="X Value", y_axis_label="Y Value")
21+
22+
# Plot scatter
23+
p.scatter(x="x", y="y", source=source, size=12, color="#306998", alpha=0.7)
24+
25+
# Styling
26+
p.title.text_font_size = "20pt"
27+
p.xaxis.axis_label_text_font_size = "20pt"
28+
p.yaxis.axis_label_text_font_size = "20pt"
29+
p.xaxis.major_label_text_font_size = "16pt"
30+
p.yaxis.major_label_text_font_size = "16pt"
31+
p.grid.grid_line_alpha = 0.3
32+
33+
# Save
34+
export_png(p, filename="plot.png")
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""
2+
scatter-basic: Basic Scatter Plot
3+
Library: highcharts
4+
"""
5+
6+
import tempfile
7+
import time
8+
import urllib.request
9+
from pathlib import Path
10+
11+
from highcharts_core.chart import Chart
12+
from highcharts_core.options import HighchartsOptions
13+
from highcharts_core.options.series.scatter import ScatterSeries
14+
from selenium import webdriver
15+
from selenium.webdriver.chrome.options import Options
16+
17+
18+
# Data
19+
x = [1, 2, 3, 4, 5, 6, 7, 8]
20+
y = [2.1, 4.3, 3.2, 5.8, 4.9, 7.2, 6.1, 8.5]
21+
22+
# Create chart with container
23+
chart = Chart(container="container")
24+
chart.options = HighchartsOptions()
25+
26+
# Chart configuration
27+
chart.options.chart = {"type": "scatter", "width": 4800, "height": 2700, "backgroundColor": "#ffffff"}
28+
29+
# Title
30+
chart.options.title = {"text": "Basic Scatter Plot", "style": {"fontSize": "60px"}}
31+
32+
# Axes
33+
chart.options.x_axis = {
34+
"title": {"text": "X Value", "style": {"fontSize": "48px"}},
35+
"labels": {"style": {"fontSize": "40px"}},
36+
"gridLineWidth": 1,
37+
"gridLineColor": "rgba(0, 0, 0, 0.1)",
38+
}
39+
chart.options.y_axis = {
40+
"title": {"text": "Y Value", "style": {"fontSize": "48px"}},
41+
"labels": {"style": {"fontSize": "40px"}},
42+
"gridLineWidth": 1,
43+
"gridLineColor": "rgba(0, 0, 0, 0.1)",
44+
}
45+
46+
# Legend (not needed for single series, but kept minimal)
47+
chart.options.legend = {"enabled": False}
48+
49+
# Add series
50+
series = ScatterSeries()
51+
series.data = list(zip(x, y, strict=False))
52+
series.name = "Data"
53+
series.marker = {"radius": 20, "fillColor": "#306998", "lineWidth": 2, "lineColor": "#306998"}
54+
chart.add_series(series)
55+
56+
# Download Highcharts JS for inline embedding
57+
highcharts_url = "https://code.highcharts.com/highcharts.js"
58+
with urllib.request.urlopen(highcharts_url, timeout=30) as response:
59+
highcharts_js = response.read().decode("utf-8")
60+
61+
# Generate HTML with inline scripts
62+
html_str = chart.to_js_literal()
63+
html_content = f"""<!DOCTYPE html>
64+
<html>
65+
<head>
66+
<meta charset="utf-8">
67+
<script>{highcharts_js}</script>
68+
</head>
69+
<body style="margin:0;">
70+
<div id="container" style="width: 4800px; height: 2700px;"></div>
71+
<script>{html_str}</script>
72+
</body>
73+
</html>"""
74+
75+
# Write temp HTML and take screenshot
76+
with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False, encoding="utf-8") as f:
77+
f.write(html_content)
78+
temp_path = f.name
79+
80+
chrome_options = Options()
81+
chrome_options.add_argument("--headless")
82+
chrome_options.add_argument("--no-sandbox")
83+
chrome_options.add_argument("--disable-dev-shm-usage")
84+
chrome_options.add_argument("--disable-gpu")
85+
chrome_options.add_argument("--window-size=5000,3000")
86+
87+
driver = webdriver.Chrome(options=chrome_options)
88+
driver.get(f"file://{temp_path}")
89+
time.sleep(5)
90+
91+
# Screenshot the chart container element for exact dimensions
92+
container = driver.find_element("id", "container")
93+
container.screenshot("plot.png")
94+
driver.quit()
95+
96+
Path(temp_path).unlink()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""
2+
scatter-basic: Basic Scatter Plot
3+
Library: letsplot
4+
"""
5+
6+
import numpy as np
7+
import pandas as pd
8+
from lets_plot import LetsPlot, aes, element_text, geom_point, ggplot, ggsave, ggsize, labs, theme, theme_minimal
9+
10+
11+
LetsPlot.setup_html()
12+
13+
# Data
14+
np.random.seed(42)
15+
x = np.random.randn(100) * 2 + 10
16+
y = x * 0.8 + np.random.randn(100) * 2
17+
18+
data = pd.DataFrame({"x": x, "y": y})
19+
20+
# Plot
21+
plot = (
22+
ggplot(data, aes(x="x", y="y"))
23+
+ geom_point(color="#306998", size=4, alpha=0.7)
24+
+ labs(x="X Value", y="Y Value", title="Basic Scatter Plot")
25+
+ ggsize(1600, 900)
26+
+ theme_minimal()
27+
+ theme(plot_title=element_text(size=20), axis_title=element_text(size=20), axis_text=element_text(size=16))
28+
)
29+
30+
# Save (scale 3x to get 4800 × 2700 px)
31+
ggsave(plot, "plot.png", path=".", scale=3)
Lines changed: 16 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,27 @@
11
"""
22
scatter-basic: Basic Scatter Plot
3-
Implementation for: matplotlib
4-
Variant: default
5-
Python: 3.10+
3+
Library: matplotlib
64
"""
75

8-
from typing import TYPE_CHECKING, Optional
9-
106
import matplotlib.pyplot as plt
117
import numpy as np
12-
import pandas as pd
13-
14-
15-
if TYPE_CHECKING:
16-
from matplotlib.figure import Figure
17-
18-
19-
def create_plot(
20-
data: pd.DataFrame,
21-
x: str,
22-
y: str,
23-
figsize: tuple[float, float] = (16, 9),
24-
alpha: float = 0.6,
25-
size: float = 30,
26-
color: str = "steelblue",
27-
title: Optional[str] = None,
28-
xlabel: Optional[str] = None,
29-
ylabel: Optional[str] = None,
30-
edgecolors: Optional[str] = None,
31-
linewidth: float = 0,
32-
**kwargs,
33-
) -> "Figure":
34-
"""
35-
Create a basic scatter plot visualizing the relationship between two continuous variables.
36-
37-
Args:
38-
data: Input DataFrame with required columns
39-
x: Column name for x-axis values
40-
y: Column name for y-axis values
41-
figsize: Figure size as (width, height) tuple (default: (16, 9))
42-
alpha: Transparency level for points (default: 0.6 for better visibility with many points)
43-
size: Point size (default: 30)
44-
color: Point color (default: "steelblue")
45-
title: Plot title (default: None)
46-
xlabel: X-axis label (default: uses column name)
47-
ylabel: Y-axis label (default: uses column name)
48-
edgecolors: Edge color for points (default: None)
49-
linewidth: Width of edge lines (default: 0)
50-
**kwargs: Additional parameters passed to scatter function
51-
52-
Returns:
53-
Matplotlib Figure object
54-
55-
Raises:
56-
ValueError: If data is empty
57-
KeyError: If required columns not found
58-
TypeError: If x or y columns contain non-numeric data
59-
60-
Example:
61-
>>> data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 3]})
62-
>>> fig = create_plot(data, 'x', 'y')
63-
"""
64-
# Input validation
65-
if data.empty:
66-
raise ValueError("Data cannot be empty")
67-
68-
# Check required columns
69-
for col in [x, y]:
70-
if col not in data.columns:
71-
available = ", ".join(data.columns)
72-
raise KeyError(f"Column '{col}' not found. Available columns: {available}")
73-
74-
# Check if columns are numeric
75-
if not pd.api.types.is_numeric_dtype(data[x]):
76-
raise TypeError(f"Column '{x}' must contain numeric data")
77-
if not pd.api.types.is_numeric_dtype(data[y]):
78-
raise TypeError(f"Column '{y}' must contain numeric data")
79-
80-
# Create figure
81-
fig, ax = plt.subplots(figsize=figsize)
82-
83-
# Plot data
84-
ax.scatter(data[x], data[y], s=size, alpha=alpha, c=color, edgecolors=edgecolors, linewidth=linewidth, **kwargs)
85-
86-
# Labels and title
87-
ax.set_xlabel(xlabel or x)
88-
ax.set_ylabel(ylabel or y)
89-
90-
if title:
91-
ax.set_title(title)
92-
93-
# Apply styling
94-
ax.grid(True, alpha=0.3)
95-
96-
# Layout
97-
plt.tight_layout()
98-
99-
return fig
1008

1019

102-
if __name__ == "__main__":
103-
# Sample data for testing - many points to demonstrate basic scatter
104-
np.random.seed(42)
105-
n_points = 500
10+
# Data
11+
np.random.seed(42)
12+
x = np.random.randn(100) * 2 + 5
13+
y = x * 0.8 + np.random.randn(100) * 1.5
10614

107-
data = pd.DataFrame(
108-
{
109-
"x": np.random.randn(n_points) * 2 + 10,
110-
"y": np.random.randn(n_points) * 3 + 15 + np.random.randn(n_points) * 0.5,
111-
}
112-
)
15+
# Create plot
16+
fig, ax = plt.subplots(figsize=(16, 9))
17+
ax.scatter(x, y, alpha=0.7, s=80, color="#306998")
11318

114-
# Create plot
115-
fig = create_plot(data, "x", "y", title="Basic Scatter Plot Example", xlabel="X Value", ylabel="Y Value")
19+
# Labels and styling
20+
ax.set_xlabel("X Value", fontsize=20)
21+
ax.set_ylabel("Y Value", fontsize=20)
22+
ax.set_title("Basic Scatter Plot", fontsize=20)
23+
ax.tick_params(axis="both", labelsize=16)
24+
ax.grid(True, alpha=0.3)
11625

117-
# Save for inspection - ALWAYS use 'plot.png' as filename
118-
plt.savefig("plot.png", dpi=300, bbox_inches="tight")
119-
print("Plot saved to plot.png")
26+
plt.tight_layout()
27+
plt.savefig("plot.png", dpi=300, bbox_inches="tight")
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
scatter-basic: Basic Scatter Plot
3+
Library: plotly
4+
"""
5+
6+
import numpy as np
7+
import plotly.graph_objects as go
8+
9+
10+
# Data
11+
np.random.seed(42)
12+
x = np.random.randn(100) * 2 + 10
13+
y = x * 0.8 + np.random.randn(100) * 2
14+
15+
# Create figure
16+
fig = go.Figure()
17+
18+
fig.add_trace(go.Scatter(x=x, y=y, mode="markers", marker={"size": 12, "color": "#306998", "opacity": 0.7}))
19+
20+
# Layout
21+
fig.update_layout(
22+
title={"text": "Basic Scatter Plot", "font": {"size": 40}, "x": 0.5, "xanchor": "center"},
23+
xaxis={
24+
"title": {"text": "X Value", "font": {"size": 40}},
25+
"tickfont": {"size": 32},
26+
"showgrid": True,
27+
"gridcolor": "rgba(0, 0, 0, 0.1)",
28+
},
29+
yaxis={
30+
"title": {"text": "Y Value", "font": {"size": 40}},
31+
"tickfont": {"size": 32},
32+
"showgrid": True,
33+
"gridcolor": "rgba(0, 0, 0, 0.1)",
34+
},
35+
template="plotly_white",
36+
plot_bgcolor="white",
37+
margin={"l": 120, "r": 50, "t": 100, "b": 100},
38+
)
39+
40+
# Save (4800 x 2700 px)
41+
fig.write_image("plot.png", width=1600, height=900, scale=3)

0 commit comments

Comments
 (0)