1+ """
2+ heatmap-correlation: Correlation matrix heatmap showing pairwise correlations between variables
3+ Library: matplotlib
4+ """
5+
6+ import matplotlib .pyplot as plt
7+ import numpy as np
8+ import pandas as pd
9+ from matplotlib .figure import Figure
10+ from typing import TYPE_CHECKING , Optional
11+
12+ if TYPE_CHECKING :
13+ from matplotlib .figure import Figure
14+
15+
def create_plot(
    data: pd.DataFrame,
    cmap: str = 'RdBu_r',
    annot: bool = True,
    fmt: str = '.2f',
    vmin: float = -1.0,
    vmax: float = 1.0,
    cbar_label: str = 'Correlation',
    title: Optional[str] = None,
    **kwargs
) -> "Figure":
    """
    Create a correlation matrix heatmap showing pairwise correlations between variables.

    Non-numeric columns are dropped, the pairwise correlation matrix of the
    remaining columns is computed with ``DataFrame.corr()`` and drawn with
    ``imshow`` on a new 16x9 inch figure.

    Args:
        data: Input DataFrame with numeric columns to correlate
        cmap: Colormap for the heatmap (a diverging map is expected)
        annot: Whether to show values in cells
        fmt: Format spec applied to each cell value via ``format(value, fmt)``
        vmin: Minimum value for color scale
        vmax: Maximum value for color scale
        cbar_label: Label for color bar
        title: Optional plot title; omitted when None or empty
        **kwargs: Accepted for call-site compatibility; currently ignored

    Returns:
        Matplotlib Figure object. The figure is created through pyplot, so
        the caller should ``plt.close(fig)`` it when it is no longer needed.

    Raises:
        ValueError: If data is empty, has no numeric columns, or has fewer
            than two numeric columns

    Example:
        >>> data = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6], 'C': [7, 8, 9]})
        >>> fig = create_plot(data, cmap='coolwarm')
    """
    # Input validation
    if data.empty:
        raise ValueError("Data cannot be empty")

    # Select only numeric columns
    numeric_data = data.select_dtypes(include=[np.number])

    if numeric_data.empty:
        raise ValueError("Data must contain at least one numeric column")

    if len(numeric_data.columns) < 2:
        raise ValueError("Data must contain at least 2 numeric columns for correlation")

    # Calculate correlation matrix
    corr_matrix = numeric_data.corr()
    corr_values = corr_matrix.values
    labels = corr_matrix.columns
    n_vars = len(labels)

    # Create figure
    fig, ax = plt.subplots(figsize=(16, 9))

    # Create heatmap using imshow
    im = ax.imshow(corr_values, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')

    # One major tick per variable, centred on its row/column
    positions = np.arange(n_vars)
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Rotate the tick labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Add text annotations if requested
    if annot:
        # A diverging colormap is darkest at both ends of the color range, so
        # switch to white text once a value lies in the outer 30% of the range
        # on either side of its centre. With the default vmin=-1 / vmax=1 this
        # is exactly "abs(value) > 0.7".
        centre = (vmin + vmax) / 2.0
        dark_cutoff = 0.7 * abs(vmax - vmin) / 2.0
        for (row, col), value in np.ndenumerate(corr_values):
            # NaN correlations (e.g. constant columns) compare False and stay black.
            text_color = "white" if abs(value - centre) > dark_cutoff else "black"
            ax.text(col, row, format(value, fmt),
                    ha="center", va="center",
                    color=text_color,
                    fontsize=10)

    # Add colorbar via the figure handle rather than pyplot's current-figure state
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(cbar_label, rotation=270, labelpad=15)

    # Styling
    ax.set_xlabel('')
    ax.set_ylabel('')

    # Add title if provided
    if title:
        ax.set_title(title, pad=20)

    # Add subtle grid for cell boundaries: minor ticks sit on the cell edges
    # (half a unit off the cell centres) and only the minor grid is drawn.
    cell_edges = np.arange(n_vars + 1) - 0.5
    ax.set_xticks(cell_edges, minor=True)
    ax.set_yticks(cell_edges, minor=True)
    ax.grid(which="minor", color="gray", linestyle='-', linewidth=0.3, alpha=0.3)
    ax.tick_params(which="minor", size=0)

    # Layout
    fig.tight_layout()

    return fig
115+
116+
if __name__ == '__main__':
    # Sample data for testing: a fixed seed keeps the demo plot reproducible.
    np.random.seed(42)
    n = 100

    # Humidity, Wind Speed and Rainfall are placeholders here and are
    # overwritten below; their draws are kept so the random sequence (and
    # therefore the resulting plot) stays the same.
    data = pd.DataFrame({
        'Temperature': np.random.normal(20, 5, n),
        'Humidity': np.random.normal(60, 10, n),
        'Pressure': np.random.normal(1013, 10, n),
        'Wind Speed': np.random.normal(10, 3, n),
        'Rainfall': np.random.normal(5, 2, n),
    })

    # Add correlations: derive some columns from others plus noise so the
    # heatmap shows non-trivial off-diagonal values.
    data['Humidity'] = 100 - data['Temperature'] * 1.5 + np.random.normal(0, 5, n)
    data['Rainfall'] = data['Humidity'] * 0.1 + np.random.normal(0, 1, n)
    data['Wind Speed'] = 15 - data['Pressure'] * 0.01 + np.random.normal(0, 2, n)

    # Create plot
    fig = create_plot(data)

    # Save - ALWAYS use 'plot.png'!
    # Save through the returned figure rather than pyplot's "current figure".
    fig.savefig('plot.png', dpi=300, bbox_inches='tight')
    print("Plot saved to plot.png")
0 commit comments