Skip to content

Commit fe35559

Browse files
feat(matplotlib): implement heatmap-correlation
- Add correlation matrix heatmap visualization - Display correlation coefficients with diverging color scheme - Include value annotations in cells - Add color bar with scale from -1 to 1 - Implement clear cell boundaries and readable labels Parent Issue: #53 Sub-Issue: #54
1 parent 989c85d commit fe35559

2 files changed

Lines changed: 209 additions & 0 deletions

File tree

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""
2+
heatmap-correlation: Correlation matrix heatmap showing pairwise correlations between variables
3+
Library: matplotlib
4+
"""
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pandas as pd
9+
from matplotlib.figure import Figure
10+
from typing import TYPE_CHECKING, Optional
11+
12+
if TYPE_CHECKING:
13+
from matplotlib.figure import Figure
14+
15+
16+
def create_plot(
17+
data: pd.DataFrame,
18+
cmap: str = 'RdBu_r',
19+
annot: bool = True,
20+
fmt: str = '.2f',
21+
vmin: float = -1.0,
22+
vmax: float = 1.0,
23+
cbar_label: str = 'Correlation',
24+
title: Optional[str] = None,
25+
**kwargs
26+
) -> Figure:
27+
"""
28+
Create a correlation matrix heatmap showing pairwise correlations between variables.
29+
30+
Args:
31+
data: Input DataFrame with numeric columns to correlate
32+
cmap: Colormap for the heatmap (diverging)
33+
annot: Whether to show values in cells
34+
fmt: Format for cell annotations
35+
vmin: Minimum value for color scale
36+
vmax: Maximum value for color scale
37+
cbar_label: Label for color bar
38+
title: Optional plot title
39+
**kwargs: Additional parameters
40+
41+
Returns:
42+
Matplotlib Figure object
43+
44+
Raises:
45+
ValueError: If data is empty or has no numeric columns
46+
KeyError: If required columns not found
47+
48+
Example:
49+
>>> data = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6], 'C': [7, 8, 9]})
50+
>>> fig = create_plot(data, cmap='coolwarm')
51+
"""
52+
# Input validation
53+
if data.empty:
54+
raise ValueError("Data cannot be empty")
55+
56+
# Select only numeric columns
57+
numeric_data = data.select_dtypes(include=[np.number])
58+
59+
if numeric_data.empty:
60+
raise ValueError("Data must contain at least one numeric column")
61+
62+
if len(numeric_data.columns) < 2:
63+
raise ValueError("Data must contain at least 2 numeric columns for correlation")
64+
65+
# Calculate correlation matrix
66+
corr_matrix = numeric_data.corr()
67+
68+
# Create figure
69+
fig, ax = plt.subplots(figsize=(16, 9))
70+
71+
# Create heatmap using imshow
72+
im = ax.imshow(corr_matrix.values, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
73+
74+
# Set ticks and labels
75+
ax.set_xticks(np.arange(len(corr_matrix.columns)))
76+
ax.set_yticks(np.arange(len(corr_matrix.columns)))
77+
ax.set_xticklabels(corr_matrix.columns)
78+
ax.set_yticklabels(corr_matrix.columns)
79+
80+
# Rotate the tick labels for better readability
81+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
82+
83+
# Add text annotations if requested
84+
if annot:
85+
for i in range(len(corr_matrix.columns)):
86+
for j in range(len(corr_matrix.columns)):
87+
value = corr_matrix.iloc[i, j]
88+
text = ax.text(j, i, format(value, fmt),
89+
ha="center", va="center",
90+
color="white" if abs(value) > 0.7 else "black",
91+
fontsize=10)
92+
93+
# Add colorbar
94+
cbar = plt.colorbar(im, ax=ax)
95+
cbar.set_label(cbar_label, rotation=270, labelpad=15)
96+
97+
# Styling
98+
ax.set_xlabel('')
99+
ax.set_ylabel('')
100+
101+
# Add title if provided
102+
if title:
103+
ax.set_title(title, pad=20)
104+
105+
# Add subtle grid for cell boundaries
106+
ax.set_xticks(np.arange(len(corr_matrix.columns) + 1) - 0.5, minor=True)
107+
ax.set_yticks(np.arange(len(corr_matrix.columns) + 1) - 0.5, minor=True)
108+
ax.grid(which="minor", color="gray", linestyle='-', linewidth=0.3, alpha=0.3)
109+
ax.tick_params(which="minor", size=0)
110+
111+
# Layout
112+
plt.tight_layout()
113+
114+
return fig
115+
116+
117+
if __name__ == '__main__':
118+
# Sample data for testing
119+
import numpy as np
120+
121+
np.random.seed(42)
122+
n = 100
123+
124+
data = pd.DataFrame({
125+
'Temperature': np.random.normal(20, 5, n),
126+
'Humidity': np.random.normal(60, 10, n),
127+
'Pressure': np.random.normal(1013, 10, n),
128+
'Wind Speed': np.random.normal(10, 3, n),
129+
'Rainfall': np.random.normal(5, 2, n)
130+
})
131+
132+
# Add correlations
133+
data['Humidity'] = 100 - data['Temperature'] * 1.5 + np.random.normal(0, 5, n)
134+
data['Rainfall'] = data['Humidity'] * 0.1 + np.random.normal(0, 1, n)
135+
data['Wind Speed'] = 15 - data['Pressure'] * 0.01 + np.random.normal(0, 2, n)
136+
137+
# Create plot
138+
fig = create_plot(data)
139+
140+
# Save - ALWAYS use 'plot.png'!
141+
plt.savefig('plot.png', dpi=300, bbox_inches='tight')
142+
print("Plot saved to plot.png")

specs/heatmap-correlation.md

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# heatmap-correlation
2+
3+
Correlation matrix heatmap showing pairwise correlations between variables
4+
5+
## Requirements
6+
7+
### Data
8+
- DataFrame with 4-6 numeric variables suitable for correlation analysis
9+
- Compute correlation matrix using pandas `.corr()` method
10+
11+
### Visual
12+
- Display correlation coefficients (-1 to 1) as colors
13+
- Use diverging color scheme (blue-white-red)
14+
- Show correlation values in cells (formatted to 2 decimal places)
15+
- Variable names as axis labels
16+
- Color bar legend showing the scale
17+
- Clear cell boundaries
18+
- Readable text annotations
19+
- Color scale centered at 0
20+
21+
### Parameters
22+
23+
#### Required
24+
- `data`: pd.DataFrame - Input data with numeric columns to correlate
25+
26+
#### Optional
27+
- `cmap`: str = 'RdBu_r' - Colormap for the heatmap (diverging)
28+
- `annot`: bool = True - Whether to show values in cells
29+
- `fmt`: str = '.2f' - Format for cell annotations
30+
- `vmin`: float = -1.0 - Minimum value for color scale
31+
- `vmax`: float = 1.0 - Maximum value for color scale
32+
- `cbar_label`: str = 'Correlation' - Label for color bar
33+
- `title`: str = None - Optional plot title
34+
35+
## Sample Data
36+
37+
```python
38+
import pandas as pd
39+
import numpy as np
40+
41+
np.random.seed(42)
42+
n = 100
43+
44+
data = pd.DataFrame({
45+
'Temperature': np.random.normal(20, 5, n),
46+
'Humidity': np.random.normal(60, 10, n),
47+
'Pressure': np.random.normal(1013, 10, n),
48+
'Wind Speed': np.random.normal(10, 3, n),
49+
'Rainfall': np.random.normal(5, 2, n)
50+
})
51+
52+
# Add correlations
53+
data['Humidity'] = 100 - data['Temperature'] * 1.5 + np.random.normal(0, 5, n)
54+
data['Rainfall'] = data['Humidity'] * 0.1 + np.random.normal(0, 1, n)
55+
data['Wind Speed'] = 15 - data['Pressure'] * 0.01 + np.random.normal(0, 2, n)
56+
```
57+
58+
## Expected Output
59+
60+
A heatmap showing:
61+
- 5x5 correlation matrix
62+
- Values from -1 to 1 displayed in each cell
63+
- Blue for negative correlations
64+
- Red for positive correlations
65+
- White for no correlation (0)
66+
- Color bar on the right side
67+
- Variable names on both axes

0 commit comments

Comments
 (0)