Skip to content

Commit d74a205

Browse files
committed
waveform_ms calculation
1 parent 2aabe50 commit d74a205

1 file changed

Lines changed: 23 additions & 3 deletions

File tree

batbot/spectrogram/__init__.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,25 @@ def generate_waveplot(
243243

244244
return waveplot
245245

246+
def get_waveform_data_ms(waveform, sample_rate, hop_length=16):
    """
    Compute a min/max amplitude envelope of a waveform on a millisecond time axis.

    The waveform is edge-padded by ``hop_length // 2`` samples on both sides and
    cut into consecutive, non-overlapping windows of ``hop_length`` samples. The
    minimum and maximum of every window form the envelope, and window ``i`` is
    stamped with the time of sample ``i * hop_length``.

    Args:
        waveform: 1-D sequence of audio samples.
        sample_rate: Number of samples per second of ``waveform``.
        hop_length: Number of samples per envelope bin; must be >= 1.

    Returns:
        A tuple ``(times_ms, bin_mins, bin_maxs)`` of three 1-D arrays of equal
        length: the bin times in milliseconds, the per-bin minimum amplitude and
        the per-bin maximum amplitude. All three are empty for an empty waveform.
        NOTE: this is a plain tuple, not a 2-D array, so it cannot be indexed
        with ``result[:, a:b]``; slice each of the three arrays individually.

    Raises:
        ValueError: If ``hop_length`` is smaller than 1.
    """
    if hop_length < 1:
        raise ValueError(f'hop_length must be a positive integer, got {hop_length!r}')

    waveform = np.asarray(waveform)

    # np.pad(mode='edge') cannot extend an empty axis, so short-circuit here
    # instead of failing with an opaque error from inside numpy
    if waveform.size == 0:
        return (
            np.empty(0, dtype=float),
            np.empty(0, dtype=waveform.dtype),
            np.empty(0, dtype=waveform.dtype),
        )

    # 1. Calculate min/max envelopes
    # Pad half a hop on each side so that the first and last bins are complete,
    # then keep every hop_length-th sliding window (i.e. non-overlapping bins)
    padded = np.pad(waveform, hop_length // 2, mode='edge')
    windows = np.lib.stride_tricks.sliding_window_view(padded, (hop_length,))
    views = windows[::hop_length]

    bin_mins = np.min(views, axis=1)
    bin_maxs = np.max(views, axis=1)

    # 2. Calculate time in milliseconds
    # (index * hop_length) = total samples
    # (samples / sample_rate) = seconds
    # (seconds * 1000) = milliseconds
    times_ms = (np.arange(len(bin_mins)) * hop_length) / sample_rate * 1000

    return times_ms, bin_mins, bin_maxs
264+
246265

247266
# @lp
248267
def load_stft(
@@ -306,14 +325,15 @@ def load_stft(
306325
else:
307326
waveplot = generate_waveplot(waveform, stft_db, hop_length=hop_length)
308327

328+
waveform_ms = get_waveform_data_ms(waveform, sample_rate=sr, hop_length=hop_length)
309329
# Estimate maximum frequency band containing data based on original sample rate
310330
# Only data up to this maximum band should be used when computing statistics
311331
max_band_idx = min((int(np.where(bands < orig_sr / 2.02)[0][-1]), len(bands) - 1))
312332
# set non-physical noise above the max band to a minimum value
313333
if max_band_idx < len(bands) - 1:
314334
stft_db[max_band_idx + 1 :, :] = np.min(stft_db[: max_band_idx + 1, :])
315335

316-
return stft_db, waveplot, sr, bands, duration, min_index, time_vec, orig_sr, max_band_idx
336+
return stft_db, waveplot, sr, bands, duration, min_index, time_vec, orig_sr, max_band_idx, waveform_ms
317337

318338

319339
# @lp
@@ -1468,7 +1488,7 @@ def compute_wrapper(
14681488
with warnings.catch_warnings():
14691489
warnings.simplefilter('ignore', category=DeprecationWarning)
14701490
# ignore warning due to aifc deprecation
1471-
stft_db, waveplot, sr, bands, duration, freq_offset, time_vec, orig_sr, max_band_idx = (
1491+
stft_db, waveplot, sr, bands, duration, freq_offset, time_vec, orig_sr, max_band_idx, waveform_ms = (
14721492
load_stft(wav_filepath, fast_mode=fast_mode)
14731493
)
14741494

@@ -1738,7 +1758,7 @@ def compute_wrapper(
17381758
segments['waveplot'].append(segment_waveplot)
17391759
# convert to JSON serializable datatype and add to metadata if segment_waves is True
17401760
if segment_waves:
1741-
segment_waveplot = segment_waveplot.tolist()
1761+
segment_waveplot = waveform_ms[:, start + trim_begin : start + trim_end]
17421762
metadata_waveplot = {
17431763
"waveplot": segment_waveplot,
17441764
}

0 commit comments

Comments
 (0)