|
| 1 | +""" pyplots.ai |
| 2 | +scatter-matrix: Scatter Plot Matrix |
| 3 | +Library: pygal 3.1.0 | Python 3.13.11 |
| 4 | +Quality: 90/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +from io import BytesIO |
| 8 | + |
| 9 | +import cairosvg |
| 10 | +import numpy as np |
| 11 | +import pygal |
| 12 | +from PIL import Image, ImageDraw, ImageFont |
| 13 | +from pygal.style import Style |
| 14 | + |
| 15 | + |
# Data - synthetic, Iris-like measurements for four variables.
# The global seed makes every run produce the same sample.
np.random.seed(42)
n_samples = 100

# One shared latent factor drives all four variables, which is what makes
# them correlated with each other.
base = np.random.randn(n_samples)

# (label, centre, loading on the latent factor, independent noise scale).
# A negative loading gives a negative correlation with the other variables.
# Order matters: the noise for each variable is drawn in this sequence.
_variable_specs = (
    ("Sepal Length", 5.8, 0.8, 0.3),
    ("Sepal Width", 3.0, -0.4, 0.25),
    ("Petal Length", 3.8, 1.5, 0.4),
    ("Petal Width", 1.2, 0.6, 0.2),
)

variables = {
    label: centre + base * loading + np.random.randn(n_samples) * noise_scale
    for label, centre, loading, noise_scale in _variable_specs
}

# Keep the individual series available under their own names as well.
sepal_length, sepal_width, petal_length, petal_width = variables.values()

var_names = list(variables)
n_vars = len(var_names)
| 35 | + |
# Style configuration shared by every cell chart.
# Surface and text colours, plus the series palette.
_style_colors = {
    "background": "white",
    "plot_background": "#f8f8f8",
    "foreground": "#333",
    "foreground_strong": "#333",
    "foreground_subtle": "#666",
    "colors": ("#306998", "#FFD43B", "#4B8BBE", "#FFE873"),
}

# Font sizes for the individual text roles.
_style_font_sizes = {
    "title_font_size": 28,
    "label_font_size": 18,
    "major_label_font_size": 16,
    "legend_font_size": 16,
    "value_font_size": 14,
}

# Partial opacity keeps overlapping points distinguishable.
custom_style = Style(
    opacity=0.55,
    opacity_hover=0.85,
    **_style_colors,
    **_style_font_sizes,
)
| 52 | + |
# Canvas dimensions (all values in pixels)
total_width = total_height = 3600
margin_top = margin_bottom = margin_left = 120
margin_right = 50

# Area left for the grid of cells once the outer margins are removed.
plot_area_width = total_width - (margin_left + margin_right)
plot_area_height = total_height - (margin_top + margin_bottom)

# Cells are square, sized so that n_vars of them fit along the shorter side.
cell_size = min(plot_area_width, plot_area_height) // n_vars

# Blank space kept between neighbouring cells.
gap = 10
| 64 | + |
# Create the composite image: one pygal chart per (row, column) cell, each
# rasterised to PNG and pasted onto a single white canvas.
composite = Image.new("RGB", (total_width, total_height), "white")

# Each chart is drawn slightly smaller than its grid slot so that adjacent
# cells end up separated by `gap` pixels.
inner_size = cell_size - gap

for i, var_y in enumerate(var_names):
    is_bottom_row = i == n_vars - 1
    for j, var_x in enumerate(var_names):
        is_left_column = j == 0

        # Options common to both chart types. Tick labels are shown only on
        # the outer edge of the matrix (bottom row / left column), and those
        # cells get a wider margin to leave room for them.
        cell_options = {
            "width": inner_size,
            "height": inner_size,
            "style": custom_style,
            "show_legend": False,
            "show_x_labels": is_bottom_row,
            "show_y_labels": is_left_column,
            "x_label_rotation": 0,
            "show_minor_x_labels": False,
            "show_minor_y_labels": False,
            "margin_top": 8,
            "margin_right": 8,
            "margin_bottom": 40 if is_bottom_row else 8,
            "margin_left": 70 if is_left_column else 8,
            "truncate_label": -1,
        }

        if i == j:
            # Diagonal: distribution of a single variable.
            chart = pygal.Histogram(spacing=0, **cell_options)
            counts, bin_edges = np.histogram(variables[var_x], bins=12)
            # pygal describes each histogram bar as (height, x_start, x_end);
            # the bin count is the height and must come first.
            hist_data = [
                (float(count), float(start), float(end))
                for count, start, end in zip(counts, bin_edges[:-1], bin_edges[1:])
            ]
            chart.add(var_x, hist_data)
        else:
            # Off-diagonal: column variable on x, row variable on y, drawn as
            # unconnected dots.
            chart = pygal.XY(dots_size=7, stroke=False, **cell_options)
            scatter_data = [(float(x), float(y)) for x, y in zip(variables[var_x], variables[var_y])]
            chart.add("Data", scatter_data)

        # pygal renders SVG; convert it to a PNG of exactly the cell size.
        svg_bytes = chart.render()
        png_bytes = cairosvg.svg2png(bytestring=svg_bytes, output_width=inner_size, output_height=inner_size)
        cell_image = Image.open(BytesIO(png_bytes))

        # Offset by half the gap so the chart is centred inside its grid slot.
        x_pos = margin_left + j * cell_size + gap // 2
        y_pos = margin_top + i * cell_size + gap // 2
        composite.paste(cell_image, (x_pos, y_pos))
| 137 | + |
# Add title and variable labels directly onto the composite using PIL.
draw = ImageDraw.Draw(composite)

# Try to use DejaVu fonts from the usual Linux location; fall back to PIL's
# built-in default font when the files cannot be opened.
try:
    title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 56)
    label_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 36)
except OSError:
    title_font = ImageFont.load_default()
    label_font = ImageFont.load_default()

# Title, centred horizontally across the whole canvas
title_text = "scatter-matrix · pygal · pyplots.ai"
title_bbox = draw.textbbox((0, 0), title_text, font=title_font)
title_width = title_bbox[2] - title_bbox[0]
draw.text(((total_width - title_width) // 2, 35), title_text, fill="#333", font=title_font)

# Variable labels along the bottom (one per column) and the left (one per row)
for idx, var_name in enumerate(var_names):
    # Offset of this row/column's centre from the start of the grid.
    cell_center = idx * cell_size + cell_size // 2

    # Measure the label once; both placements below use it.
    left, top, right, bottom = draw.textbbox((0, 0), var_name, font=label_font)
    text_width = right - left
    text_height = bottom - top

    # Bottom label: centred under its column, just below the grid.
    x_label_pos = margin_left + cell_center
    y_label_pos = margin_top + n_vars * cell_size + 40
    draw.text((x_label_pos - text_width // 2, y_label_pos), var_name, fill="#333", font=label_font)

    # Left label: render the text on a transparent scratch image sized to the
    # text itself, then rotate it. A tight scratch image means the rotated
    # image's centre is the text's centre, and long labels are never clipped.
    txt_img = Image.new("RGBA", (max(text_width, 1), max(text_height, 1)), (255, 255, 255, 0))
    txt_draw = ImageDraw.Draw(txt_img)
    # Shift by the bbox origin so the glyphs start exactly at the image corner.
    txt_draw.text((-left, -top), var_name, fill="#333", font=label_font)
    txt_rotated = txt_img.rotate(90, expand=True)

    # Centre the rotated label vertically on its row; the image doubles as
    # its own alpha mask so only the glyph pixels are painted.
    x_label_pos = 15
    paste_y = margin_top + cell_center - txt_rotated.height // 2
    composite.paste(txt_rotated, (x_label_pos, paste_y), txt_rotated)
| 179 | + |
# Write the finished matrix to disk as a PNG tagged with 300 DPI.
composite.save("plot.png", format="PNG", dpi=(300, 300))
0 commit comments