-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotnine.py
More file actions
189 lines (170 loc) · 6.43 KB
/
plotnine.py
File metadata and controls
189 lines (170 loc) · 6.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
""" pyplots.ai
spectrogram-mel: Mel-Spectrogram for Audio Analysis
Library: plotnine 0.15.3 | Python 3.14.3
Quality: 90/100 | Created: 2026-03-11
"""
import numpy as np
import pandas as pd
from plotnine import (
aes,
coord_cartesian,
element_blank,
element_rect,
element_text,
geom_raster,
geom_segment,
geom_text,
ggplot,
guide_colorbar,
guides,
labs,
scale_fill_gradientn,
scale_x_continuous,
scale_y_continuous,
theme,
theme_minimal,
)
from scipy.signal import stft
# Data - a seeded, synthetic 3-second test signal: five sinusoids with
# time-varying envelopes, a little decaying noise, and a rising sweep.
np.random.seed(42)  # fixed seed so the noise term is reproducible
sample_rate = 22050
duration = 3.0
n_samples = int(sample_rate * duration)
t = np.linspace(0, duration, n_samples, endpoint=False)

fundamental = 220
two_pi = 2 * np.pi

# Each component is a sine at a fixed frequency multiplied by its own envelope.
decaying_f0 = 0.6 * np.sin(two_pi * fundamental * t) * np.exp(-0.3 * t)
tremolo_440 = 0.4 * np.sin(two_pi * 440 * t) * (0.5 + 0.5 * np.sin(two_pi * 1.5 * t))
decaying_880 = 0.3 * np.sin(two_pi * 880 * t) * np.exp(-0.5 * t)
fading_1320 = 0.2 * np.sin(two_pi * 1320 * t) * (1 - t / duration)
decaying_3300 = 0.15 * np.sin(two_pi * 3300 * t) * np.exp(-1.0 * t)
decaying_noise = 0.1 * np.random.randn(n_samples) * np.exp(-0.8 * t)

# Summed left to right in this fixed order.
signal = decaying_f0 + tremolo_440 + decaying_880 + fading_1320 + decaying_3300 + decaying_noise

# Linear sweep from 500 Hz to 4000 Hz, active only for 0.8 s < t < 2.0 s.
chirp_mask = (t > 0.8) & (t < 2.0)
chirp_freq = 500 + (4000 - 500) * (t[chirp_mask] - 0.8) / 1.2
# Phase is the running sum of the instantaneous frequency divided by the sample rate.
chirp_phase = two_pi * np.cumsum(chirp_freq) / sample_rate
signal[chirp_mask] += 0.35 * np.sin(chirp_phase)
# STFT - 2048-sample segments; consecutive segments start 512 samples apart
n_fft = 2048
hop_length = 512
frame_overlap = n_fft - hop_length  # stft is given the overlap, so derive it from the hop
stft_out = stft(signal, fs=sample_rate, nperseg=n_fft, noverlap=frame_overlap)
# stft returns three values; the first one (discarded by the original unpacking too)
# is not kept, only the segment times and the complex transform are used.
time_bins = stft_out[1]
Zxx = stft_out[2]
magnitude = np.abs(Zxx)
power_spec = magnitude ** 2  # squared magnitude of every STFT coefficient
# Mel filterbank - n_mels triangular filters whose edges are evenly spaced in mel
def _hz_to_mel(hz):
    """Map a frequency in Hz to mel using 2595 * log10(1 + hz / 700)."""
    return 2595.0 * np.log10(1.0 + hz / 700.0)


def _mel_to_hz(mel):
    """Map mel back to Hz using 700 * (10 ** (mel / 2595) - 1)."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


n_mels = 128
# One frequency per row of power_spec, evenly spaced from 0 to sample_rate / 2.
freq_bins = np.linspace(0, sample_rate / 2, power_spec.shape[0])
mel_low = _hz_to_mel(0)
mel_high = _hz_to_mel(sample_rate / 2)
# n_mels filters need n_mels + 2 edge points: each filter uses three consecutive points.
mel_points = np.linspace(mel_low, mel_high, n_mels + 2)
hz_points = _mel_to_hz(mel_points)

# Broadcast (n_mels, 1) filter edges against (1, n_freq) bin frequencies.
left_edge = hz_points[:-2, np.newaxis]
peak = hz_points[1:-1, np.newaxis]
right_edge = hz_points[2:, np.newaxis]
bin_hz = freq_bins[np.newaxis, :]

# A bin is on the upslope between left edge and peak, on the downslope between
# peak and right edge; zero-width slopes are excluded from the masks.
on_upslope = (bin_hz >= left_edge) & (bin_hz <= peak) & (peak != left_edge)
on_downslope = (bin_hz > peak) & (bin_hz <= right_edge) & (right_edge != peak)

# The two masks never overlap, so the sum is the full (n_mels, n_freq) triangle bank.
filterbank = (
    np.where(on_upslope, (bin_hz - left_edge) / (peak - left_edge), 0.0)
    + np.where(on_downslope, (right_edge - bin_hz) / (right_edge - peak), 0.0)
)
# Project the power spectrogram onto the mel filters, then express it in dB
# relative to its own maximum (so the loudest cell is 0 dB, everything else <= 0).
mel_spec = filterbank @ power_spec
power_floor = 1e-10  # lower bound applied before log10 so zero power stays finite
mel_spec_db = 10 * np.log10(np.maximum(mel_spec, power_floor))
peak_db = mel_spec_db.max()
mel_spec_db = mel_spec_db - peak_db
# Long-form table for geom_raster: one row per (mel band, time bin) cell.
# The y position is the integer band index, so bands are evenly spaced on the axis.
band_center_mels = mel_points[1:-1]  # interior mel points, one per band
mel_center_freqs = 700.0 * (10.0 ** (band_center_mels / 2595.0) - 1.0)
n_frames = len(time_bins)
df = pd.DataFrame(
    {
        # Row order follows mel_spec_db.ravel(): all time bins of band 0, then band 1, ...
        "Time (s)": np.tile(time_bins, n_mels),
        "mel_band": np.repeat(np.arange(n_mels), n_frames),
        "Power (dB)": mel_spec_db.ravel(),
    }
)
# Y-axis ticks: labelled in Hz, positioned in (fractional) mel band index
nyquist_hz = sample_rate / 2
candidate_ticks_hz = (128, 256, 512, 1024, 2048, 4096, 8000)
# Keep only the ticks that do not exceed half the sample rate.
y_ticks_hz = [hz for hz in candidate_ticks_hz if hz <= nyquist_hz]
# Linear interpolation between band centre frequencies gives each tick's band position.
band_positions = np.arange(n_mels)
y_ticks_band = np.interp(y_ticks_hz, mel_center_freqs, band_positions)
# Annotation data - small frames that feed the geom_text / geom_segment layers
# Band positions of the two marked tones (220 Hz and 880 Hz), interpolated the
# same way as the y-axis ticks.
marked_positions = np.interp([220, 880], mel_center_freqs, np.arange(n_mels))
f0_band = float(marked_positions[0])
h3_band = float(marked_positions[1])
marked_bands = [f0_band, h3_band]

# NOTE(review): 880 Hz is 4x the 220 Hz fundamental; the "3rd" label presumably
# counts it as the third tone of the 220/440/880 stack - confirm the wording.
# NOTE(review): the plot layers visible in this file pass their colours as
# literals, so the "color" column looks informational only - confirm.
df_labels = pd.DataFrame(
    {
        "x": [2.85] * 2,
        "y": marked_bands,
        "label": ["F\u2080", "3rd"],
        "color": ["#fcffa4", "#fb9b06"],
    }
)

# One horizontal line per marked tone, spanning x = 0 to x = duration.
df_reflines = pd.DataFrame(
    {
        "x": [0.0] * 2,
        "xend": [duration] * 2,
        "y": marked_bands,
        "yend": marked_bands,
    }
)
# Plot - raster spectrogram with text labels and faint guide lines for two tones

# Ten dark-to-bright colour stops used for the fill gradient.
power_palette = [
    "#000004",
    "#1b0c41",
    "#4a0c6b",
    "#781c6d",
    "#a52c60",
    "#cf4446",
    "#ed6925",
    "#fb9b06",
    "#f7d13d",
    "#fcffa4",
]

# Greys shared by several theme elements.
title_grey = "#e0e0e0"
label_grey = "#cccccc"
tick_grey = "#aaaaaa"


def _tone_label(row, **text_style):
    """Return a right-aligned geom_text layer built from one row of df_labels."""
    return geom_text(
        aes(x="x", y="y", label="label"),
        data=df_labels.iloc[[row]],
        inherit_aes=False,
        ha="right",
        **text_style,
    )


def _tone_guide(row, **line_style):
    """Return a geom_segment layer built from one row of df_reflines."""
    return geom_segment(
        aes(x="x", xend="xend", y="y", yend="yend"),
        data=df_reflines.iloc[[row]],
        inherit_aes=False,
        **line_style,
    )


fill_scale = scale_fill_gradientn(colors=power_palette, name="Power (dB)")

# Row 0 is the 220 Hz tone, row 1 the 880 Hz tone; each gets its own styling.
f0_label = _tone_label(0, color="#fcffa4", size=11, fontweight="bold", alpha=0.85)
h3_label = _tone_label(1, color="#fb9b06", size=9, alpha=0.7)
f0_guide = _tone_guide(0, color="#fcffa4", alpha=0.15, size=0.4)
h3_guide = _tone_guide(1, color="#fb9b06", alpha=0.12, size=0.3)

# Ticks sit at fractional band positions but are labelled with their Hz values.
y_scale = scale_y_continuous(
    breaks=y_ticks_band.tolist(),
    labels=[str(f) for f in y_ticks_hz],
    expand=(0, 0),
)

# Dark styling layered on top of theme_minimal: light text, no grid lines.
dark_theme = theme(
    figure_size=(16, 9),
    text=element_text(family="sans-serif"),
    plot_title=element_text(size=24, ha="center", weight="bold", color=title_grey, margin={"b": 8}),
    axis_title_x=element_text(size=20, color=label_grey, margin={"t": 10}),
    axis_title_y=element_text(size=20, color=label_grey, margin={"r": 8}),
    axis_text_x=element_text(size=16, color=tick_grey),
    axis_text_y=element_text(size=16, color=tick_grey),
    legend_title=element_text(size=16, weight="bold", color=label_grey),
    legend_text=element_text(size=14, color=tick_grey),
    legend_position="right",
    legend_key_height=60,
    legend_key_width=14,
    panel_grid_major=element_blank(),
    panel_grid_minor=element_blank(),
    panel_background=element_rect(fill="#000004", color="none"),
    plot_background=element_rect(fill="#0e0e1a", color="none"),
    plot_margin=0.02,
)

# Components are added strictly left to right, in the same order as before:
# raster, fill scale, colourbar guide, labels, guide lines, axes, labels, themes.
plot = (
    ggplot(df, aes(x="Time (s)", y="mel_band", fill="Power (dB)"))
    + geom_raster(interpolate=True)
    + fill_scale
    + guides(fill=guide_colorbar(nbin=256, display="raster"))
    + f0_label
    + h3_label
    + f0_guide
    + h3_guide
    + scale_x_continuous(expand=(0, 0))
    + y_scale
    + coord_cartesian(ylim=(0, n_mels - 1))  # y view limited to band indices 0..n_mels-1
    + labs(x="Time (s)", y="Frequency (Hz)", title="spectrogram-mel \u00b7 plotnine \u00b7 pyplots.ai")
    + theme_minimal()
    + dark_theme
)

# Write the figure to plot.png in the current working directory.
plot.save("plot.png", dpi=300, verbose=False)