Skip to content

Commit 8ae9db1

Browse files
feat(matplotlib): implement scatter-basic (#316)
## Summary Implements `scatter-basic` for **matplotlib** library. **Parent Issue:** #207 **Sub-Issue:** #231 **Base Branch:** `plot/scatter-basic` **Attempt:** 1/3 ## Implementation - `plots/matplotlib/scatter/scatter-basic/default.py` ## Changes - Simplified implementation to follow KISS style guide (plot-generator.md) - Uses sequential script structure: imports → data → plot → save - No functions, classes, or type hints per guidelines - Follows default-style-guide.md for colors (Python Blue #306998) and dimensions (16:9 aspect) - Proper font sizes for 4800x2700px output Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
1 parent 27357ca commit 8ae9db1

1 file changed

Lines changed: 16 additions & 108 deletions

File tree

Lines changed: 16 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,27 @@
11
"""
22
scatter-basic: Basic Scatter Plot
3-
Implementation for: matplotlib
4-
Variant: default
5-
Python: 3.10+
3+
Library: matplotlib
64
"""
75

8-
from typing import TYPE_CHECKING, Optional
9-
106
import matplotlib.pyplot as plt
117
import numpy as np
12-
import pandas as pd
13-
14-
15-
if TYPE_CHECKING:
16-
from matplotlib.figure import Figure
17-
18-
19-
def create_plot(
20-
data: pd.DataFrame,
21-
x: str,
22-
y: str,
23-
figsize: tuple[float, float] = (16, 9),
24-
alpha: float = 0.6,
25-
size: float = 30,
26-
color: str = "steelblue",
27-
title: Optional[str] = None,
28-
xlabel: Optional[str] = None,
29-
ylabel: Optional[str] = None,
30-
edgecolors: Optional[str] = None,
31-
linewidth: float = 0,
32-
**kwargs,
33-
) -> "Figure":
34-
"""
35-
Create a basic scatter plot visualizing the relationship between two continuous variables.
36-
37-
Args:
38-
data: Input DataFrame with required columns
39-
x: Column name for x-axis values
40-
y: Column name for y-axis values
41-
figsize: Figure size as (width, height) tuple (default: (16, 9))
42-
alpha: Transparency level for points (default: 0.6 for better visibility with many points)
43-
size: Point size (default: 30)
44-
color: Point color (default: "steelblue")
45-
title: Plot title (default: None)
46-
xlabel: X-axis label (default: uses column name)
47-
ylabel: Y-axis label (default: uses column name)
48-
edgecolors: Edge color for points (default: None)
49-
linewidth: Width of edge lines (default: 0)
50-
**kwargs: Additional parameters passed to scatter function
51-
52-
Returns:
53-
Matplotlib Figure object
54-
55-
Raises:
56-
ValueError: If data is empty
57-
KeyError: If required columns not found
58-
TypeError: If x or y columns contain non-numeric data
59-
60-
Example:
61-
>>> data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 3]})
62-
>>> fig = create_plot(data, 'x', 'y')
63-
"""
64-
# Input validation
65-
if data.empty:
66-
raise ValueError("Data cannot be empty")
67-
68-
# Check required columns
69-
for col in [x, y]:
70-
if col not in data.columns:
71-
available = ", ".join(data.columns)
72-
raise KeyError(f"Column '{col}' not found. Available columns: {available}")
73-
74-
# Check if columns are numeric
75-
if not pd.api.types.is_numeric_dtype(data[x]):
76-
raise TypeError(f"Column '{x}' must contain numeric data")
77-
if not pd.api.types.is_numeric_dtype(data[y]):
78-
raise TypeError(f"Column '{y}' must contain numeric data")
79-
80-
# Create figure
81-
fig, ax = plt.subplots(figsize=figsize)
82-
83-
# Plot data
84-
ax.scatter(data[x], data[y], s=size, alpha=alpha, c=color, edgecolors=edgecolors, linewidth=linewidth, **kwargs)
85-
86-
# Labels and title
87-
ax.set_xlabel(xlabel or x)
88-
ax.set_ylabel(ylabel or y)
89-
90-
if title:
91-
ax.set_title(title)
92-
93-
# Apply styling
94-
ax.grid(True, alpha=0.3)
95-
96-
# Layout
97-
plt.tight_layout()
98-
99-
return fig
1008

1019

102-
if __name__ == "__main__":
103-
# Sample data for testing - many points to demonstrate basic scatter
104-
np.random.seed(42)
105-
n_points = 500
10+
# Data
11+
np.random.seed(42)
12+
x = np.random.randn(100) * 2 + 5
13+
y = x * 0.8 + np.random.randn(100) * 1.5
10614

107-
data = pd.DataFrame(
108-
{
109-
"x": np.random.randn(n_points) * 2 + 10,
110-
"y": np.random.randn(n_points) * 3 + 15 + np.random.randn(n_points) * 0.5,
111-
}
112-
)
15+
# Create plot
16+
fig, ax = plt.subplots(figsize=(16, 9))
17+
ax.scatter(x, y, alpha=0.7, s=50, color="#306998")
11318

114-
# Create plot
115-
fig = create_plot(data, "x", "y", title="Basic Scatter Plot Example", xlabel="X Value", ylabel="Y Value")
19+
# Labels and styling
20+
ax.set_xlabel("X Value", fontsize=20)
21+
ax.set_ylabel("Y Value", fontsize=20)
22+
ax.set_title("Basic Scatter Plot", fontsize=20)
23+
ax.tick_params(axis="both", labelsize=16)
24+
ax.grid(True, alpha=0.3)
11625

117-
# Save for inspection - ALWAYS use 'plot.png' as filename
118-
plt.savefig("plot.png", dpi=300, bbox_inches="tight")
119-
print("Plot saved to plot.png")
26+
plt.tight_layout()
27+
plt.savefig("plot.png", dpi=300, bbox_inches="tight")

0 commit comments

Comments
 (0)