|
""" pyplots.ai
confusion-matrix: Confusion Matrix Heatmap
Library: bokeh 3.8.1 | Python 3.13.11
Quality: 96/100 | Created: 2025-12-26
"""
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from bokeh.io import export_png, output_file, save |
| 9 | +from bokeh.models import ColorBar, ColumnDataSource, LabelSet, LinearColorMapper |
| 10 | +from bokeh.plotting import figure |
| 11 | +from bokeh.transform import transform |
| 12 | + |
| 13 | + |
# Data - Multi-class classification results for a sentiment analysis model
np.random.seed(42)  # kept from the original script; nothing in this section draws random numbers

class_names = ["Negative", "Neutral", "Positive", "Very Positive"]

# Simulated confusion matrix with realistic patterns:
# - Good diagonal (correct predictions)
# - Adjacent classes more likely to be confused
# - Some class imbalance
# Rows are indexed by the true class, columns by the predicted class.
confusion = np.array(
    [
        [142, 23, 8, 2],  # Negative: mostly correct, some confused with Neutral
        [18, 98, 31, 5],  # Neutral: often confused with adjacent classes
        [5, 28, 156, 24],  # Positive: good accuracy, some confusion with Neutral/Very Positive
        [1, 4, 19, 86],  # Very Positive: smaller class, good precision
    ]
)

# Prepare data for Bokeh heatmap using rect glyphs: one entry per matrix cell,
# in row-major order (true class varies slowest, predicted class fastest).
n_classes = len(class_names)
x_coords = class_names * n_classes  # predicted label cycles through every class for each row
y_coords = [true_class for true_class in class_names for _ in range(n_classes)]
# ravel() walks the matrix in the same row-major order; tolist() yields native
# Python ints rather than NumPy scalars.
values = confusion.ravel().tolist()
text_labels = [str(val) for val in values]
| 45 | + |
# One record per matrix cell: position, raw count, and the count as display text
cell_data = {"x": x_coords, "y": y_coords, "value": values, "text": text_labels}
source = ColumnDataSource(data=cell_data)

# Color mapping - sequential blues, lightest for low counts to darkest for high counts
colors = [
    "#f7fbff",
    "#deebf7",
    "#c6dbef",
    "#9ecae1",
    "#6baed6",
    "#4292c6",
    "#2171b5",
    "#08519c",
    "#08306b",
]
mapper = LinearColorMapper(palette=colors, low=0, high=max(values))

# Create figure - square canvas; the y-range is reversed so the first class sits at the top
p = figure(
    title="confusion-matrix · bokeh · pyplots.ai",
    width=3600,
    height=3600,
    x_range=class_names,
    y_range=class_names[::-1],
    x_axis_label="Predicted Label",
    y_axis_label="True Label",
    tools="",
    toolbar_location=None,
)

# Draw heatmap cells: unit-sized rects filled according to the cell count,
# separated by white borders
cell_fill = transform("value", mapper)
p.rect(
    x="x",
    y="y",
    width=1,
    height=1,
    source=source,
    fill_color=cell_fill,
    line_color="#FFFFFF",
    line_width=3,
)
| 76 | + |
# Add text annotations for cell values
# Pick a contrasting text color per cell: white on the darker cells (counts above
# half of the largest count), dark navy on the lighter ones.
# The threshold is computed once instead of re-evaluating max(values) per cell.
dark_cell_threshold = max(values) * 0.5
text_colors = ["#FFFFFF" if val > dark_cell_threshold else "#08306b" for val in values]

# Store the colors as a data-source column so LabelSet can look them up per label
source.data["text_color"] = text_colors

labels = LabelSet(
    x="x",
    y="y",
    text="text",
    text_color="text_color",
    text_font_size="32pt",
    text_font_style="bold",
    text_align="center",
    text_baseline="middle",
    source=source,
)
p.add_layout(labels)
| 101 | + |
# Title typography, sized for the large canvas
p.title.update(text_font_size="36pt", text_font_style="bold", align="center")

# Axis typography and line weights, applied to both axes at once
p.axis.axis_label_text_font_size = "28pt"
p.axis.major_label_text_font_size = "24pt"
p.axis.axis_line_width = 2
p.axis.major_tick_line_width = 2

# Only the x-axis tick labels are tilted slightly for readability
p.xaxis.major_label_orientation = 0.4

# No grid lines - cleaner heatmap look
p.grid.grid_line_color = None

# Colorbar showing the count scale, placed to the right of the plot
color_bar = ColorBar(
    color_mapper=mapper,
    title="Count",
    location=(0, 0),
    width=30,
    padding=40,
    label_standoff=12,
    title_text_font_size="22pt",
    major_label_text_font_size="18pt",
    bar_line_color="#08306b",
    bar_line_width=2,
)
p.add_layout(color_bar, "right")

# Minimum padding around the plot area
p.update(
    min_border_left=150,
    min_border_right=150,
    min_border_top=100,
    min_border_bottom=150,
)
| 143 | + |
# Output targets, written to the current working directory
png_path = "plot.png"
html_path = "plot.html"

# Static image first
export_png(p, filename=png_path)

# Then the interactive HTML version
output_file(html_path)
save(p)
0 commit comments