Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/osekit/core_api/base_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def begin(self, value: Timestamp) -> None:
Begin can only be set to a posterior date from the original begin.

"""
self.items = [item for item in self.items if item.end >= value]
for item in self.items:
item.begin = max(item.begin, value)

Expand All @@ -118,6 +119,7 @@ def end(self) -> Timestamp:
@end.setter
def end(self, value: Timestamp) -> None:
"""Return true if every item of this data object is empty."""
self.items = [item for item in self.items if item.begin < value]
for item in self.items:
item.end = min(item.end, value)

Expand Down
47 changes: 40 additions & 7 deletions src/osekit/core_api/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import bisect
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
Expand All @@ -15,15 +15,48 @@
class Event:
"""Events are bounded between begin an end attributes.

Classes that have a begin and an end should inherit from Event.
Classes that have a beginning and an end should inherit from Event.
"""

begin: Timestamp
end: Timestamp
_begin: Timestamp = field(init=False, repr=False, compare=True)
_end: Timestamp = field(init=False, repr=False, compare=True)

def __init__(
self,
begin: Timestamp,
end: Timestamp
) -> None:
"""Initialize an Event instance with a beginning and an end."""
self.begin = begin
self.end = end

@property
def begin(self) -> Timestamp:
"""Beginning of the event."""
return self._begin

@begin.setter
def begin(self, value: Timestamp) -> None:
if hasattr(self, "_end") and value >= self._end:
msg = f"Invalid Event: `end` ({self._end}) must be greater than `begin` ({value})." # noqa: E501
raise ValueError(msg)
self._begin = value

@property
def end(self) -> Timestamp:
"""End of the event."""
return self._end

@end.setter
def end(self, value: Timestamp) -> None:
if hasattr(self, "_begin") and value <= self._begin:
msg = f"Invalid Event: `end` ({value}) must be greater than `begin` ({self._begin})." # noqa: E501
raise ValueError(msg)
self._end = value

@property
def duration(self) -> Timedelta:
"""Return the total duration of the data in seconds."""
"""Duration of the event."""
return self.end - self.begin

def overlaps(self, other: type[Event] | Event) -> bool:
Expand Down Expand Up @@ -115,8 +148,8 @@ def remove_overlaps(cls, events: list[TEvent]) -> list[TEvent]:

""" # noqa: E501
events = sorted(
[copy.copy(event) for event in events],
key=lambda event: (event.begin, event.end),
events,
key=lambda event: (event.begin, -1*event.duration),
)
concatenated_events = []
for event in events:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,8 @@ def test_audio_dataset_from_files(
"corrupted_audio_files",
"non_audio_files",
"error",
"begin",
"end",
),
[
pytest.param(
Expand All @@ -1119,6 +1121,8 @@ def test_audio_dataset_from_files(
FileNotFoundError,
match="No valid file found in ",
),
None,
None,
id="no_file",
),
pytest.param(
Expand All @@ -1139,6 +1143,8 @@ def test_audio_dataset_from_files(
FileNotFoundError,
match="No valid file found in ",
),
None,
None,
id="corrupted_audio_files",
),
pytest.param(
Expand Down Expand Up @@ -1166,6 +1172,8 @@ def test_audio_dataset_from_files(
],
[],
None,
None,
None,
id="mixed_audio_files",
),
pytest.param(
Expand All @@ -1189,6 +1197,8 @@ def test_audio_dataset_from_files(
+ ".csv",
],
None,
None,
None,
id="non_audio_files_are_not_logged",
),
pytest.param(
Expand All @@ -1214,6 +1224,8 @@ def test_audio_dataset_from_files(
FileNotFoundError,
match="No valid file found in ",
),
None,
None,
id="all_but_ok_audio",
),
pytest.param(
Expand Down Expand Up @@ -1246,15 +1258,42 @@ def test_audio_dataset_from_files(
+ ".csv",
],
None,
None,
None,
id="full_mix",
),
pytest.param(
{
"duration": 1,
"sample_rate": 48_000,
"nb_files": 3,
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
"series_type": "increase",
},
generate_sample_audio(
nb_files=1,
nb_samples=144_000,
series_type="increase",
),
[],
[],
pytest.raises(
ValueError,
match=r"`end` .* must be greater than `begin`",
),
pd.Timestamp("2024-01-01 12:01:00"),
pd.Timestamp("2024-01-01 12:00:00"),
id="datetime_mismatch",
),
],
indirect=["audio_files"],
)
def test_audio_dataset_from_folder_errors_warnings(
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
audio_files: tuple[list[Path], pytest.fixtures.Subrequest],
begin: pd.Timestamp | None,
end: pd.Timestamp | None,
expected_audio_data: list[np.ndarray],
corrupted_audio_files: list[str],
non_audio_files: list[str],
Expand All @@ -1270,6 +1309,8 @@ def test_audio_dataset_from_folder_errors_warnings(
AudioDataset.from_folder(
tmp_path,
strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED,
begin=begin,
end=end,
)
== e
)
Expand Down
106 changes: 106 additions & 0 deletions tests/test_event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from contextlib import nullcontext

import pytest
from pandas import Timestamp

Expand Down Expand Up @@ -326,3 +328,107 @@ def test_get_overlapping_events(
)

assert len(overlap_result) == len(expected_result)


@pytest.mark.parametrize(
("event", "updated_begin", "updated_end", "expected"),
[
pytest.param(
Event(
begin=Timestamp("2024-01-01 00:00:00"),
end=Timestamp("2024-01-02 00:00:00"),
),
Timestamp("2024-01-01 12:00:00"),
None,
nullcontext(
Event(
begin=Timestamp("2024-01-01 12:00:00"),
end=Timestamp("2024-01-02 00:00:00"),
)
),
id="valid_begin",
),
pytest.param(
Event(
begin=Timestamp("2024-01-01 00:00:00"),
end=Timestamp("2024-01-02 00:00:00"),
),
None,
Timestamp("2024-01-02 12:00:00"),
nullcontext(
Event(
begin=Timestamp("2024-01-01 00:00:00"),
end=Timestamp("2024-01-02 12:00:00"),
)
),
id="valid_end",
),
pytest.param(
Event(
begin=Timestamp("2024-01-01 00:00:00"),
end=Timestamp("2024-01-02 00:00:00"),
),
Timestamp("2024-01-03 00:00:00"),
None,
pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"),
id="invalid_begin_after_end",
),
pytest.param(
Event(
begin=Timestamp("2024-01-01 00:00:00"),
end=Timestamp("2024-01-02 00:00:00"),
),
None,
Timestamp("2023-12-31 23:59:59"),
pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"),
id="invalid_end_before_begin",
),
pytest.param(
Event(
begin=Timestamp("2024-01-01 00:00:00"),
end=Timestamp("2024-01-01 01:00:00"),
),
Timestamp("2024-01-01 01:00:00"),
None,
pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"),
id="begin_equals_end",
),
],
)
def test_event_begin_end_updates(
event: Event,
updated_begin: Timestamp | None,
updated_end: Timestamp | None,
expected: Event,
) -> None:
def update_event(
cool_event: Event, begin: Timestamp | None, end: Timestamp | None
) -> Event:
if begin:
cool_event.begin = begin
if end:
cool_event.end = end
return cool_event

with expected as e:
assert update_event(event, updated_begin, updated_end) == e


@pytest.mark.parametrize(
("begin", "end"),
[
pytest.param(
Timestamp("2024-01-02 00:00:00"),
Timestamp("2024-01-01 00:00:00"),
id="begin_after_end",
),
pytest.param(
Timestamp("2024-01-01 00:00:00"),
Timestamp("2024-01-01 00:00:00"),
id="begin_equals_end",
),
],
)
def test_event_errors(begin: Timestamp, end: Timestamp) -> None:
with pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*") as e:
assert Event(begin=begin, end=end) == e