|
1 | 1 | """ |
2 | 2 | scatter-basic: Basic Scatter Plot |
3 | | -Implementation for: matplotlib |
4 | | -Variant: default |
5 | | -Python: 3.10+ |
| 3 | +Library: matplotlib |
6 | 4 | """ |
7 | 5 |
|
8 | | -from typing import TYPE_CHECKING, Optional |
9 | | - |
10 | 6 | import matplotlib.pyplot as plt |
11 | 7 | import numpy as np |
12 | | -import pandas as pd |
13 | | - |
14 | | - |
15 | | -if TYPE_CHECKING: |
16 | | - from matplotlib.figure import Figure |
17 | | - |
18 | | - |
19 | | -def create_plot( |
20 | | - data: pd.DataFrame, |
21 | | - x: str, |
22 | | - y: str, |
23 | | - figsize: tuple[float, float] = (16, 9), |
24 | | - alpha: float = 0.6, |
25 | | - size: float = 30, |
26 | | - color: str = "steelblue", |
27 | | - title: Optional[str] = None, |
28 | | - xlabel: Optional[str] = None, |
29 | | - ylabel: Optional[str] = None, |
30 | | - edgecolors: Optional[str] = None, |
31 | | - linewidth: float = 0, |
32 | | - **kwargs, |
33 | | -) -> "Figure": |
34 | | - """ |
35 | | - Create a basic scatter plot visualizing the relationship between two continuous variables. |
36 | | -
|
37 | | - Args: |
38 | | - data: Input DataFrame with required columns |
39 | | - x: Column name for x-axis values |
40 | | - y: Column name for y-axis values |
41 | | - figsize: Figure size as (width, height) tuple (default: (16, 9)) |
42 | | - alpha: Transparency level for points (default: 0.6 for better visibility with many points) |
43 | | - size: Point size (default: 30) |
44 | | - color: Point color (default: "steelblue") |
45 | | - title: Plot title (default: None) |
46 | | - xlabel: X-axis label (default: uses column name) |
47 | | - ylabel: Y-axis label (default: uses column name) |
48 | | - edgecolors: Edge color for points (default: None) |
49 | | - linewidth: Width of edge lines (default: 0) |
50 | | - **kwargs: Additional parameters passed to scatter function |
51 | | -
|
52 | | - Returns: |
53 | | - Matplotlib Figure object |
54 | | -
|
55 | | - Raises: |
56 | | - ValueError: If data is empty |
57 | | - KeyError: If required columns not found |
58 | | - TypeError: If x or y columns contain non-numeric data |
59 | | -
|
60 | | - Example: |
61 | | - >>> data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 3]}) |
62 | | - >>> fig = create_plot(data, 'x', 'y') |
63 | | - """ |
64 | | - # Input validation |
65 | | - if data.empty: |
66 | | - raise ValueError("Data cannot be empty") |
67 | | - |
68 | | - # Check required columns |
69 | | - for col in [x, y]: |
70 | | - if col not in data.columns: |
71 | | - available = ", ".join(data.columns) |
72 | | - raise KeyError(f"Column '{col}' not found. Available columns: {available}") |
73 | | - |
74 | | - # Check if columns are numeric |
75 | | - if not pd.api.types.is_numeric_dtype(data[x]): |
76 | | - raise TypeError(f"Column '{x}' must contain numeric data") |
77 | | - if not pd.api.types.is_numeric_dtype(data[y]): |
78 | | - raise TypeError(f"Column '{y}' must contain numeric data") |
79 | | - |
80 | | - # Create figure |
81 | | - fig, ax = plt.subplots(figsize=figsize) |
82 | | - |
83 | | - # Plot data |
84 | | - ax.scatter(data[x], data[y], s=size, alpha=alpha, c=color, edgecolors=edgecolors, linewidth=linewidth, **kwargs) |
85 | | - |
86 | | - # Labels and title |
87 | | - ax.set_xlabel(xlabel or x) |
88 | | - ax.set_ylabel(ylabel or y) |
89 | | - |
90 | | - if title: |
91 | | - ax.set_title(title) |
92 | | - |
93 | | - # Apply styling |
94 | | - ax.grid(True, alpha=0.3) |
95 | | - |
96 | | - # Layout |
97 | | - plt.tight_layout() |
98 | | - |
99 | | - return fig |
100 | 8 |
|
101 | 9 |
|
102 | | -if __name__ == "__main__": |
103 | | - # Sample data for testing - many points to demonstrate basic scatter |
104 | | - np.random.seed(42) |
105 | | - n_points = 500 |
| 10 | +# Data |
| 11 | +np.random.seed(42) |
| 12 | +x = np.random.randn(100) * 2 + 5 |
| 13 | +y = x * 0.8 + np.random.randn(100) * 1.5 |
106 | 14 |
|
107 | | - data = pd.DataFrame( |
108 | | - { |
109 | | - "x": np.random.randn(n_points) * 2 + 10, |
110 | | - "y": np.random.randn(n_points) * 3 + 15 + np.random.randn(n_points) * 0.5, |
111 | | - } |
112 | | - ) |
| 15 | +# Create plot |
| 16 | +fig, ax = plt.subplots(figsize=(16, 9)) |
| 17 | +ax.scatter(x, y, alpha=0.7, s=80, color="#306998") |
113 | 18 |
|
114 | | - # Create plot |
115 | | - fig = create_plot(data, "x", "y", title="Basic Scatter Plot Example", xlabel="X Value", ylabel="Y Value") |
| 19 | +# Labels and styling |
| 20 | +ax.set_xlabel("X Value", fontsize=20) |
| 21 | +ax.set_ylabel("Y Value", fontsize=20) |
| 22 | +ax.set_title("Basic Scatter Plot", fontsize=20) |
| 23 | +ax.tick_params(axis="both", labelsize=16) |
| 24 | +ax.grid(True, alpha=0.3) |
116 | 25 |
|
117 | | - # Save for inspection - ALWAYS use 'plot.png' as filename |
118 | | - plt.savefig("plot.png", dpi=300, bbox_inches="tight") |
119 | | - print("Plot saved to plot.png") |
| 26 | +plt.tight_layout() |
| 27 | +plt.savefig("plot.png", dpi=300, bbox_inches="tight") |
0 commit comments