Skip to content

Commit a4d09d5

Browse files
authored
[FEATURE] Butterworth filter (#381)
* add Butterworth dataclass * add butterworth filtering in AudioData.get_value() * add AudioDataset.butter property * add butter parameter in Transform constructor * move filtering to a new AudioData.get_filtered_value() method * add butter filtering test * add Butterworth serialization test * add __hash__() method to Butterworth class * add AudioDataset butter tests
1 parent 7ad9dd3 commit a4d09d5

7 files changed

Lines changed: 248 additions & 8 deletions

File tree

src/osekit/core/audio_data.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from osekit.core.audio_item import AudioItem
2121
from osekit.core.base_data import BaseData
2222
from osekit.core.instrument import Instrument
23-
from osekit.utils.audio import Normalization, normalize
23+
from osekit.utils.audio import Butterworth, Normalization, normalize
2424

2525
if TYPE_CHECKING:
2626
from pathlib import Path
@@ -45,6 +45,7 @@ def __init__(
4545
instrument: Instrument | None = None,
4646
normalization: Normalization = Normalization.RAW,
4747
normalization_values: dict | None = None,
48+
butter: Butterworth | None = None,
4849
) -> None:
4950
"""Initialize an ``AudioData`` from a list of ``AudioItems``.
5051
@@ -67,13 +68,16 @@ def __init__(
6768
the wav audio data.
6869
normalization: Normalization
6970
The type of normalization to apply to the audio data.
71+
butter: Butterworth | None
72+
Butterworth filter to apply to the audio data.
7073
7174
"""
7275
super().__init__(items=items, begin=begin, end=end, name=name)
7376
self._set_sample_rate(sample_rate=sample_rate)
7477
self.instrument = instrument
7578
self.normalization = normalization
7679
self.normalization_values = normalization_values
80+
self.butter = butter
7781

7882
@property
7983
def nb_channels(self) -> int:
@@ -123,6 +127,15 @@ def normalization_values(self, value: dict | None) -> None:
123127
}
124128
)
125129

130+
@property
131+
def butter(self) -> Butterworth:
132+
"""The Butterworth filter to apply to the audio data."""
133+
return self._butter
134+
135+
@butter.setter
136+
def butter(self, value: Butterworth) -> None:
137+
self._butter = value
138+
126139
@classmethod
127140
def _make_item(
128141
cls,
@@ -178,7 +191,7 @@ def get_normalization_values(self) -> dict:
178191
"std": standard deviation used for z-score normalization
179192
180193
"""
181-
values = np.array(self.get_raw_value())
194+
values = np.array(self.get_filtered_value())
182195
self.normalization_values = {
183196
"mean": values.mean(),
184197
"peak": values.max(),
@@ -222,6 +235,22 @@ def get_raw_value(self) -> np.ndarray:
222235
"""
223236
return np.vstack(list(self.stream()))
224237

238+
def get_filtered_value(self) -> np.ndarray:
239+
"""Return the value of the audio data after filtering.
240+
241+
Returns
242+
-------
243+
np.ndarray:
244+
The value of the audio data filtered by the ``self.butter`` Butterworth filter.
245+
246+
"""
247+
output = self.get_raw_value()
248+
return (
249+
output
250+
if self.butter is None
251+
else self.butter.filter(sig=output, fs=self.sample_rate)
252+
)
253+
225254
@staticmethod
226255
def _flush(
227256
resampler: soxr.ResampleStream,
@@ -320,7 +349,7 @@ def get_value(self) -> np.ndarray:
320349
321350
"""
322351
return normalize(
323-
values=self.get_raw_value(),
352+
values=self.get_filtered_value(),
324353
normalization=self.normalization,
325354
**self.normalization_values,
326355
)
@@ -547,9 +576,13 @@ def to_dict(self) -> dict:
547576
None if self.instrument is None else self.instrument.to_dict()
548577
),
549578
}
579+
butter_dict = {
580+
"butter": (None if self.butter is None else self.butter.to_dict()),
581+
}
550582
return (
551583
base_dict
552584
| instrument_dict
585+
| butter_dict
553586
| {
554587
"sample_rate": self.sample_rate,
555588
"normalization": self.normalization.value,
@@ -595,6 +628,11 @@ def _from_base_dict(
595628
if dictionary["instrument"] is None
596629
else Instrument.from_dict(dictionary["instrument"])
597630
)
631+
butter = (
632+
None
633+
if "butter" not in dictionary or dictionary["butter"] is None
634+
else Butterworth.from_dict(dictionary["butter"])
635+
)
598636
return cls.from_files(
599637
files=files,
600638
begin=begin,
@@ -603,6 +641,7 @@ def _from_base_dict(
603641
sample_rate=dictionary["sample_rate"],
604642
normalization=Normalization(dictionary["normalization"]),
605643
normalization_values=dictionary["normalization_values"],
644+
butter=butter,
606645
)
607646

608647
@classmethod
@@ -641,6 +680,9 @@ def from_files(
641680
normalization: Normalization
642681
The type of normalization to apply to the audio data.
643682
683+
butter: Butterworth
684+
Butterworth filter to apply to the audio data.
685+
644686
Returns
645687
-------
646688
Self:

src/osekit/core/audio_dataset.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from osekit.core.audio_file import AudioFile
1414
from osekit.core.base_dataset import BaseDataset
1515
from osekit.core.json_serializer import deserialize_json
16-
from osekit.utils.audio import Normalization
16+
from osekit.utils.audio import Butterworth, Normalization
1717
from osekit.utils.multiprocess import multiprocess
1818

1919
if TYPE_CHECKING:
@@ -89,6 +89,17 @@ def normalization(self, normalization: Normalization) -> None:
8989
for data in self.data:
9090
data.normalization = normalization
9191

92+
@property
93+
def butter(self) -> Butterworth:
94+
"""Return the most frequent Butterworth filter among those of this dataset data."""
95+
butters = [data.butter for data in self.data]
96+
return max(set(butters), key=butters.count)
97+
98+
@butter.setter
99+
def butter(self, butter: Butterworth) -> None:
100+
for data in self.data:
101+
data.butter = butter
102+
92103
@property
93104
def instrument(self) -> Instrument | None:
94105
"""Instrument that can be used to get acoustic pressure from wav audio data."""
@@ -187,6 +198,7 @@ def from_folder( # noqa: PLR0913
187198
name: str | None = None,
188199
instrument: Instrument | None = None,
189200
normalization: Normalization = Normalization.RAW,
201+
butter: Butterworth | None = None,
190202
**kwargs, # noqa: ANN003
191203
) -> Self:
192204
"""Return an ``AudioDataset`` from a folder containing the audio files.
@@ -240,6 +252,8 @@ def from_folder( # noqa: PLR0913
240252
the wav audio data.
241253
normalization: Normalization
242254
The type of normalization to apply to the audio data.
255+
butter: Butterworth | None
256+
Butterworth filter to apply to the audio data.
243257
kwargs: any
244258
Keyword arguments passed to the ``BaseDataset.from_folder()`` classmethod.
245259
@@ -262,6 +276,7 @@ def from_folder( # noqa: PLR0913
262276
name=name,
263277
instrument=instrument,
264278
normalization=normalization,
279+
butter=butter,
265280
)
266281

267282
@classmethod
@@ -277,6 +292,7 @@ def from_files( # noqa: PLR0913
277292
sample_rate: float | None = None,
278293
instrument: Instrument | None = None,
279294
normalization: Normalization = Normalization.RAW,
295+
butter: Butterworth | None = None,
280296
) -> AudioDataset:
281297
"""Return an AudioDataset object from a list of AudioFiles.
282298
@@ -317,6 +333,8 @@ def from_files( # noqa: PLR0913
317333
the wav audio data.
318334
normalization: Normalization
319335
The type of normalization to apply to the audio data.
336+
butter: Butterworth | None
337+
Butterworth filter to apply to the audio data.
320338
321339
Returns
322340
-------
@@ -335,6 +353,7 @@ def from_files( # noqa: PLR0913
335353
mode=mode,
336354
overlap=overlap,
337355
data_duration=data_duration,
356+
butter=butter,
338357
)
339358

340359
@classmethod

src/osekit/public/project.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def prepare_audio(self, transform: Transform) -> AudioDataset:
294294
mode=transform.mode,
295295
overlap=transform.overlap,
296296
normalization=transform.normalization,
297+
butter=transform.butter,
297298
name=transform.name,
298299
instrument=self.instrument,
299300
)

src/osekit/public/transform.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from enum import Flag, auto
44
from typing import TYPE_CHECKING, Literal
55

6-
from osekit.utils.audio import Normalization
6+
from osekit.utils.audio import Butterworth, Normalization
77

88
if TYPE_CHECKING:
99
from pandas import Timedelta, Timestamp
@@ -73,6 +73,7 @@ def __init__(
7373
overlap: float = 0.0,
7474
sample_rate: float | None = None,
7575
normalization: Normalization = Normalization.RAW,
76+
butter: Butterworth | None = None,
7677
name: str | None = None,
7778
subtype: str | None = None,
7879
fft: ShortTimeFFT | None = None,
@@ -118,6 +119,8 @@ def __init__(
118119
will be set to the one of the original dataset.
119120
normalization: Normalization
120121
The type of normalization to apply to the audio data.
122+
butter: Butterworth | None
123+
Butterworth filter to apply to the audio data.
121124
name: str | None
122125
Name of the transform dataset.
123126
Defaulted as the begin timestamp of the transform dataset.
@@ -160,6 +163,7 @@ def __init__(
160163
self.sample_rate = sample_rate
161164
self.name = name
162165
self.normalization = normalization
166+
self.butter = butter
163167
self.subtype = subtype
164168
self.v_lim = v_lim
165169
self.colormap = colormap

src/osekit/utils/audio.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
import enum
5+
from collections.abc import Iterable
46
from typing import Literal, Self
57

68
import numpy as np
79
import soxr
810
from pandas import Timedelta
11+
from scipy import signal
912

1013
from osekit.config import (
1114
resample_quality_settings,
@@ -203,3 +206,101 @@ def normalize(
203206
if Normalization.ZSCORE in normalization:
204207
values = normalize_zscore(values=values, mean=mean, std=std)
205208
return values
209+
210+
211+
@dataclasses.dataclass
212+
class Butterworth:
213+
"""Class that represent a Butterworth sos filter.
214+
215+
Parameters
216+
----------
217+
N: int
218+
The order of the filter.
219+
For "bandpass" and "bandstop" filters, the resulting order of the final
220+
second-order sections ("sos") matrix is ``2*N``,
221+
with ``N`` the number of biquad sections of the desired system.
222+
Wn: Iterable | int | float
223+
The critical frequency or frequencies.
224+
For lowpass and highpass filters, ``Wn`` is a scalar.
225+
For bandpass and bandstop filters, ``Wn`` is a length-2 sequence.
226+
For a Butterworth filter, this is the point at which the gain
227+
drops to ``1/sqrt(2)`` that of the passband (the “-3 dB point”).
228+
For digital filters, if ``fs`` is not specified,
229+
``Wn`` units are normalized from ``0`` to ``1``,
230+
where ``1`` is the Nyquist frequency
231+
(``Wn`` is thus in half cycles / sample and defined as
232+
``2*critical frequencies / fs``).
233+
If ``fs`` is specified, ``Wn`` is in the same units as ``fs``.
234+
For analog filters, ``Wn`` is an angular frequency (e.g. ``rad/s``).
235+
btype: Literal["lowpass", "highpass", "bandpass", "bandstop"]
236+
The type of filter. Default is "lowpass".
237+
238+
"""
239+
240+
N: int
241+
Wn: Iterable | int | float
242+
btype: Literal["lowpass", "highpass", "bandpass", "bandstop"] = "lowpass"
243+
244+
def to_dict(self) -> dict:
245+
"""Serialize a Butterworth sos filter to a dictionary.
246+
247+
Returns
248+
-------
249+
dict:
250+
Serialized Butterworth sos filter.
251+
252+
"""
253+
return {
254+
"N": self.N,
255+
"Wn": self.Wn,
256+
"btype": self.btype,
257+
}
258+
259+
@classmethod
260+
def from_dict(cls, data: dict) -> Butterworth:
261+
"""Deserialize a Butterworth sos filter from a dictionary.
262+
263+
Parameters
264+
----------
265+
data: dict
266+
Serialized Butterworth sos filter.
267+
268+
Returns
269+
-------
270+
Butterworth:
271+
The Butterworth sos filter.
272+
273+
"""
274+
return cls(
275+
N=data["N"],
276+
Wn=data["Wn"],
277+
btype=data["btype"],
278+
)
279+
280+
def filter(self, sig: np.typing.NDArray, fs: float) -> np.typing.NDArray:
281+
"""Filter an input signal with the Butterworth sos filter.
282+
283+
Parameters
284+
----------
285+
sig: np.typing.NDArray
286+
Input signal
287+
fs: float
288+
Sampling frequency of the signal
289+
290+
Returns
291+
-------
292+
np.typing.NDArray
293+
Filtered signal
294+
295+
"""
296+
sos = signal.butter(
297+
N=self.N,
298+
Wn=self.Wn,
299+
btype=self.btype,
300+
fs=fs,
301+
output="sos",
302+
)
303+
return signal.sosfilt(sos=sos, x=sig, axis=0)
304+
305+
def __hash__(self) -> int:
306+
return hash((self.N, self.Wn, self.btype))

0 commit comments

Comments
 (0)