|
| 1 | +""" |
| 2 | +box-basic: Basic Box Plot |
| 3 | +Implementation for: altair |
| 4 | +Variant: default |
| 5 | +Python: 3.10+ |
| 6 | +""" |
| 7 | + |
| 8 | +from typing import TYPE_CHECKING, Optional |
| 9 | + |
| 10 | +import altair as alt |
| 11 | +import numpy as np |
| 12 | +import pandas as pd |
| 13 | + |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + from altair import Chart |
| 17 | + |
| 18 | + |
def create_plot(
    data: pd.DataFrame,
    values: str,
    groups: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    color_scheme: str = "set2",
    width: int = 800,
    height: int = 450,
    **kwargs,
) -> "alt.LayerChart":
    """
    Create a basic box plot showing the distribution of several groups using altair.

    The result is a layered chart: a box plot (1.5 * IQR whiskers, outliers
    shown) plus a text layer that prints each group's sample size just below
    the group's minimum value.

    Args:
        data: Input DataFrame with required columns
        values: Column name containing numeric values
        groups: Column name containing group categories
        title: Plot title (optional, defaults to "Box Plot Distribution")
        xlabel: Custom x-axis label (optional, defaults to groups column name)
        ylabel: Custom y-axis label (optional, defaults to values column name)
        color_scheme: Color scheme for boxes (default: 'set2')
        width: Figure width in pixels (default: 800)
        height: Figure height in pixels (default: 450)
        **kwargs: Accepted for forward compatibility; currently not used

    Returns:
        Altair layered chart (box plot + sample-size annotations)

    Raises:
        ValueError: If data is empty
        KeyError: If required columns not found

    Example:
        >>> data = pd.DataFrame({
        ...     'Group': ['A', 'A', 'B', 'B', 'C', 'C'],
        ...     'Value': [1, 2, 2, 3, 3, 4]
        ... })
        >>> chart = create_plot(data, values='Value', groups='Group')
    """
    # NOTE: the return annotation is a string on purpose. Annotations are
    # evaluated when the function is defined, and a name imported only under
    # TYPE_CHECKING does not exist at runtime.

    # Input validation
    if data.empty:
        raise ValueError("Data cannot be empty")

    # Check required columns
    for col in (values, groups):
        if col not in data.columns:
            # Column labels are not guaranteed to be strings (e.g. integers),
            # so convert them before joining.
            available = ", ".join(map(str, data.columns))
            raise KeyError(f"Column '{col}' not found. Available columns: {available}")

    # Rotate the category labels once there are too many to fit horizontally.
    label_angle = 0 if data[groups].nunique() <= 5 else -45

    # Create the box plot using Altair's mark_boxplot
    base = (
        alt.Chart(data)
        .mark_boxplot(
            extent=1.5,  # 1.5 * IQR for whiskers
            outliers=True,
            size=40,
            opacity=0.7,
        )
        .encode(
            x=alt.X(
                f"{groups}:N",
                title=xlabel or groups,
                axis=alt.Axis(labelAngle=label_angle, labelLimit=200),
            ),
            y=alt.Y(f"{values}:Q", title=ylabel or values, scale=alt.Scale(zero=False)),
            color=alt.Color(
                f"{groups}:N",
                scale=alt.Scale(scheme=color_scheme),
                legend=None,  # Hide legend as it's redundant with x-axis
            ),
            tooltip=[
                alt.Tooltip(f"{groups}:N", title="Group"),
                alt.Tooltip(f"count({values}):Q", title="Count"),
                alt.Tooltip(f"min({values}):Q", title="Min", format=".2f"),
                alt.Tooltip(f"q1({values}):Q", title="Q1", format=".2f"),
                alt.Tooltip(f"median({values}):Q", title="Median", format=".2f"),
                alt.Tooltip(f"q3({values}):Q", title="Q3", format=".2f"),
                alt.Tooltip(f"max({values}):Q", title="Max", format=".2f"),
            ],
        )
    )

    # Add sample size annotations.
    # Transforms run before encodings, so everything the text layer needs
    # (per-group count and per-group minimum) must be produced by the
    # aggregate itself; the raw values column no longer exists afterwards.
    text = (
        alt.Chart(data)
        .transform_aggregate(
            sample_size="count()",
            min_value=f"min({values})",
            groupby=[groups],
        )
        .mark_text(align="center", baseline="top", dy=10, fontSize=10, opacity=0.7)
        .encode(
            x=alt.X(f"{groups}:N"),
            y=alt.Y("min_value:Q"),
            text=alt.Text("sample_size:Q", format="d"),
        )
    )

    # Combine box plot with annotations
    chart = (
        (base + text)
        .properties(
            width=width,
            height=height,
            title=alt.TitleParams(text=title or "Box Plot Distribution", fontSize=16, anchor="middle"),
        )
        .configure_view(strokeWidth=0)
        .configure_axis(grid=True, gridOpacity=0.3, gridDash=[3, 3], domainWidth=1, tickWidth=1)
        .configure_boxplot(
            median={"color": "red", "strokeWidth": 2},
            box={"strokeWidth": 1.5},
            outliers={"fill": "red", "fillOpacity": 0.5, "size": 50},
        )
    )

    return chart
| 129 | + |
| 130 | + |
if __name__ == "__main__":
    # Demo: four synthetic groups with different distributions, rendered to PNG.
    np.random.seed(42)  # For reproducibility

    # The draws below must stay in this order so the seeded stream yields
    # the same samples for every group.

    # Group A: normal(mean=50, std=10) plus a few hand-picked outliers
    group_a_data = np.append(np.random.normal(50, 10, 40), [80, 85, 15])

    # Group B: normal(mean=60, std=15) plus outliers
    group_b_data = np.append(np.random.normal(60, 15, 35), [100, 10])

    # Group C: normal(mean=45, std=8), no extra outliers
    group_c_data = np.random.normal(45, 8, 45)

    # Group D: skewed (gamma) distribution shifted by 40, plus outliers
    group_d_data = np.append(np.random.gamma(2, 2, 30) + 40, [75, 78, 20])

    samples_by_group = {
        "Group A": group_a_data,
        "Group B": group_b_data,
        "Group C": group_c_data,
        "Group D": group_d_data,
    }

    # One long-format frame: a label column and a measurement column.
    data = pd.concat(
        [
            pd.DataFrame({"Group": label, "Value": samples})
            for label, samples in samples_by_group.items()
        ],
        ignore_index=True,
    )

    # Create plot
    chart = create_plot(
        data,
        values="Value",
        groups="Group",
        title="Statistical Distribution Comparison Across Groups",
        ylabel="Measurement Value",
        xlabel="Categories",
    )

    # Save as PNG
    chart.save("plot.png", scale_factor=2.0)
    print("Plot saved to plot.png")
0 commit comments