|
| 1 | +""" |
| 2 | +area-basic: Basic Area Chart |
| 3 | +Implementation for: plotnine |
| 4 | +Variant: default |
| 5 | +Python: 3.10+ |
| 6 | +""" |
| 7 | + |
| 8 | +from typing import TYPE_CHECKING, Optional |
| 9 | + |
| 10 | +import numpy as np |
| 11 | +import pandas as pd |
| 12 | +from plotnine import aes, element_line, element_text, geom_area, geom_line, ggplot, labs, theme, theme_minimal |
| 13 | + |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + from plotnine import ggplot as GGPlot |
| 17 | + |
| 18 | + |
| 19 | +def create_plot( |
| 20 | + data: pd.DataFrame, |
| 21 | + x: str, |
| 22 | + y: str, |
| 23 | + title: Optional[str] = None, |
| 24 | + xlabel: Optional[str] = None, |
| 25 | + ylabel: Optional[str] = None, |
| 26 | + color: str = "steelblue", |
| 27 | + alpha: float = 0.6, |
| 28 | + line_color: Optional[str] = None, |
| 29 | + line_width: float = 1.5, |
| 30 | + width: int = 16, |
| 31 | + height: int = 9, |
| 32 | + **kwargs, |
| 33 | +) -> "GGPlot": |
| 34 | + """ |
| 35 | + Create a basic area chart showing values over a continuous axis using plotnine (ggplot2 syntax). |
| 36 | +
|
| 37 | + Args: |
| 38 | + data: Input DataFrame with required columns |
| 39 | + x: Column name for x-axis values (numeric or datetime) |
| 40 | + y: Column name for y-axis values (numeric) |
| 41 | + title: Plot title (optional) |
| 42 | + xlabel: Custom x-axis label (optional, defaults to x column name) |
| 43 | + ylabel: Custom y-axis label (optional, defaults to y column name) |
| 44 | + color: Fill color for the area (default: 'steelblue') |
| 45 | + alpha: Fill transparency level (default: 0.6) |
| 46 | + line_color: Color of the top line (default: same as color) |
| 47 | + line_width: Width of the top line (default: 1.5) |
| 48 | + width: Figure width in inches (default: 16) |
| 49 | + height: Figure height in inches (default: 9) |
| 50 | + **kwargs: Additional parameters for geom_area |
| 51 | +
|
| 52 | + Returns: |
| 53 | + plotnine ggplot object |
| 54 | +
|
| 55 | + Raises: |
| 56 | + ValueError: If data is empty |
| 57 | + KeyError: If required columns not found |
| 58 | +
|
| 59 | + Example: |
| 60 | + >>> data = pd.DataFrame({ |
| 61 | + ... 'time': [1, 2, 3, 4, 5], |
| 62 | + ... 'value': [10, 25, 15, 30, 20] |
| 63 | + ... }) |
| 64 | + >>> plot = create_plot(data, x='time', y='value') |
| 65 | + """ |
| 66 | + # Input validation |
| 67 | + if data.empty: |
| 68 | + raise ValueError("Data cannot be empty") |
| 69 | + |
| 70 | + # Check required columns |
| 71 | + for col in [x, y]: |
| 72 | + if col not in data.columns: |
| 73 | + available = ", ".join(data.columns) |
| 74 | + raise KeyError(f"Column '{col}' not found. Available columns: {available}") |
| 75 | + |
| 76 | + # Use the same color for line if not specified |
| 77 | + if line_color is None: |
| 78 | + line_color = color |
| 79 | + |
| 80 | + # Sort data by x to ensure proper area rendering |
| 81 | + data_sorted = data.sort_values(by=x).copy() |
| 82 | + |
| 83 | + # Create the ggplot object with area and line |
| 84 | + plot = ( |
| 85 | + ggplot(data_sorted, aes(x=x, y=y)) |
| 86 | + + geom_area(fill=color, alpha=alpha, **kwargs) |
| 87 | + + geom_line(color=line_color, size=line_width) |
| 88 | + + labs(title=title or "Area Chart", x=xlabel or x, y=ylabel or y) |
| 89 | + + theme_minimal() |
| 90 | + + theme( |
| 91 | + figure_size=(width, height), |
| 92 | + plot_title=element_text(size=14, weight="bold", ha="center"), |
| 93 | + axis_title=element_text(size=11), |
| 94 | + axis_text=element_text(size=10), |
| 95 | + panel_grid_major=element_line(alpha=0.3, linetype="dashed"), |
| 96 | + panel_grid_minor=element_line(alpha=0), |
| 97 | + ) |
| 98 | + ) |
| 99 | + |
| 100 | + return plot |
| 101 | + |
| 102 | + |
| 103 | +if __name__ == "__main__": |
| 104 | + # Sample data for testing - simulating time series data |
| 105 | + np.random.seed(42) # For reproducibility |
| 106 | + |
| 107 | + # Generate sample time series data (e.g., monthly website visitors) |
| 108 | + months = pd.date_range(start="2024-01-01", periods=12, freq="MS") |
| 109 | + |
| 110 | + # Create realistic-looking growth pattern with some variation |
| 111 | + base_values = np.linspace(1000, 2500, 12) |
| 112 | + noise = np.random.normal(0, 150, 12) |
| 113 | + values = base_values + noise |
| 114 | + |
| 115 | + # Ensure no negative values |
| 116 | + values = np.maximum(values, 100) |
| 117 | + |
| 118 | + data = pd.DataFrame( |
| 119 | + { |
| 120 | + "Month": range(1, 13), # Use numeric for simpler plotting |
| 121 | + "Visitors": values, |
| 122 | + } |
| 123 | + ) |
| 124 | + |
| 125 | + # Create plot |
| 126 | + plot = create_plot( |
| 127 | + data, |
| 128 | + x="Month", |
| 129 | + y="Visitors", |
| 130 | + title="Monthly Website Visitors (2024)", |
| 131 | + xlabel="Month", |
| 132 | + ylabel="Number of Visitors", |
| 133 | + color="#3498db", |
| 134 | + alpha=0.5, |
| 135 | + line_color="#2980b9", |
| 136 | + line_width=2, |
| 137 | + ) |
| 138 | + |
| 139 | + # Save for inspection |
| 140 | + plot.save("plot.png", dpi=300, verbose=False) |
| 141 | + print("Plot saved to plot.png") |
0 commit comments