diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index c9887215..4cde981d 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -103,6 +103,7 @@ def begin(self, value: Timestamp) -> None: Begin can only be set to a posterior date from the original begin. """ + self.items = [item for item in self.items if item.end >= value] for item in self.items: item.begin = max(item.begin, value) @@ -118,6 +119,7 @@ def end(self) -> Timestamp: @end.setter def end(self, value: Timestamp) -> None: """Return true if every item of this data object is empty.""" + self.items = [item for item in self.items if item.begin < value] for item in self.items: item.end = min(item.end, value) diff --git a/src/osekit/core_api/event.py b/src/osekit/core_api/event.py index 44c206c5..4f9d5764 100644 --- a/src/osekit/core_api/event.py +++ b/src/osekit/core_api/event.py @@ -4,7 +4,7 @@ import bisect import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, TypeVar if TYPE_CHECKING: @@ -15,15 +15,48 @@ class Event: """Events are bounded between begin an end attributes. - Classes that have a begin and an end should inherit from Event. + Classes that have a beginning and an end should inherit from Event. """ - begin: Timestamp - end: Timestamp + _begin: Timestamp = field(init=False, repr=False, compare=True) + _end: Timestamp = field(init=False, repr=False, compare=True) + + def __init__( + self, + begin: Timestamp, + end: Timestamp + ) -> None: + """Initialize an Event instance with a beginning and an end.""" + self.begin = begin + self.end = end + + @property + def begin(self) -> Timestamp: + """Beginning of the event.""" + return self._begin + + @begin.setter + def begin(self, value: Timestamp) -> None: + if hasattr(self, "_end") and value >= self._end: + msg = f"Invalid Event: `end` ({self._end}) must be greater than `begin` ({value})." # noqa: E501 + raise ValueError(msg) + self._begin = value + + @property + def end(self) -> Timestamp: + """End of the event.""" + return self._end + + @end.setter + def end(self, value: Timestamp) -> None: + if hasattr(self, "_begin") and value <= self._begin: + msg = f"Invalid Event: `end` ({value}) must be greater than `begin` ({self._begin})." # noqa: E501 + raise ValueError(msg) + self._end = value @property def duration(self) -> Timedelta: - """Return the total duration of the data in seconds.""" + """Duration of the event.""" return self.end - self.begin def overlaps(self, other: type[Event] | Event) -> bool: @@ -115,8 +148,8 @@ def remove_overlaps(cls, events: list[TEvent]) -> list[TEvent]: """ # noqa: E501 events = sorted( - [copy.copy(event) for event in events], - key=lambda event: (event.begin, event.end), + events, + key=lambda event: (event.begin, -1*event.duration), ) concatenated_events = [] for event in events: diff --git a/tests/test_audio.py b/tests/test_audio.py index 9e67d4b1..b50ab524 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -1108,6 +1108,8 @@ def test_audio_dataset_from_files( "corrupted_audio_files", "non_audio_files", "error", + "begin", + "end", ), [ pytest.param( @@ -1119,6 +1121,8 @@ def test_audio_dataset_from_files( FileNotFoundError, match="No valid file found in ", ), + None, + None, id="no_file", ), pytest.param( @@ -1139,6 +1143,8 @@ def test_audio_dataset_from_files( FileNotFoundError, match="No valid file found in ", ), + None, + None, id="corrupted_audio_files", ), pytest.param( @@ -1166,6 +1172,8 @@ def test_audio_dataset_from_files( ], [], None, + None, + None, id="mixed_audio_files", ), pytest.param( @@ -1189,6 +1197,8 @@ def test_audio_dataset_from_files( + ".csv", ], None, + None, + None, id="non_audio_files_are_not_logged", ), pytest.param( @@ -1214,6 +1224,8 @@ def test_audio_dataset_from_files( FileNotFoundError, match="No valid file found in ", ), + None, + None, id="all_but_ok_audio", ), pytest.param( @@ -1246,8 +1258,33 @@ def test_audio_dataset_from_files( + ".csv", ], None, + None, + None, id="full_mix", ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 3, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + generate_sample_audio( + nb_files=1, + nb_samples=144_000, + series_type="increase", + ), + [], + [], + pytest.raises( + ValueError, + match=r"`end` .* must be greater than `begin`", + ), + pd.Timestamp("2024-01-01 12:01:00"), + pd.Timestamp("2024-01-01 12:00:00"), + id="datetime_mismatch", + ), ], indirect=["audio_files"], ) @@ -1255,6 +1292,8 @@ def test_audio_dataset_from_folder_errors_warnings( tmp_path: Path, caplog: pytest.LogCaptureFixture, audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + begin: pd.Timestamp | None, + end: pd.Timestamp | None, expected_audio_data: list[np.ndarray], corrupted_audio_files: list[str], non_audio_files: list[str], @@ -1270,6 +1309,8 @@ def test_audio_dataset_from_folder_errors_warnings( AudioDataset.from_folder( tmp_path, strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED, + begin=begin, + end=end, ) == e ) diff --git a/tests/test_event.py b/tests/test_event.py index 9afd2350..c9e631bf 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -1,5 +1,7 @@ from __future__ import annotations +from contextlib import nullcontext + import pytest from pandas import Timestamp @@ -326,3 +328,107 @@ def test_get_overlapping_events( ) assert len(overlap_result) == len(expected_result) + + +@pytest.mark.parametrize( + ("event", "updated_begin", "updated_end", "expected"), + [ + pytest.param( + Event( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + Timestamp("2024-01-01 12:00:00"), + None, + nullcontext( + Event( + begin=Timestamp("2024-01-01 12:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ) + ), + id="valid_begin", + ), + pytest.param( + Event( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + None, + Timestamp("2024-01-02 12:00:00"), + nullcontext( + Event( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 12:00:00"), + ) + ), + id="valid_end", + ), + pytest.param( + Event( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + Timestamp("2024-01-03 00:00:00"), + None, + pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"), + id="invalid_begin_after_end", + ), + pytest.param( + Event( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + None, + Timestamp("2023-12-31 23:59:59"), + pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"), + id="invalid_end_before_begin", + ), + pytest.param( + Event( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-01 01:00:00"), + ), + Timestamp("2024-01-01 01:00:00"), + None, + pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"), + id="begin_equals_end", + ), + ], +) +def test_event_begin_end_updates( + event: Event, + updated_begin: Timestamp | None, + updated_end: Timestamp | None, + expected: Event, +) -> None: + def update_event( + cool_event: Event, begin: Timestamp | None, end: Timestamp | None + ) -> Event: + if begin: + cool_event.begin = begin + if end: + cool_event.end = end + return cool_event + + with expected as e: + assert update_event(event, updated_begin, updated_end) == e + + +@pytest.mark.parametrize( + ("begin", "end"), + [ + pytest.param( + Timestamp("2024-01-02 00:00:00"), + Timestamp("2024-01-01 00:00:00"), + id="begin_after_end", + ), + pytest.param( + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-01 00:00:00"), + id="begin_equals_end", + ), + ], +) +def test_event_errors(begin: Timestamp, end: Timestamp) -> None: + with pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*") as e: + assert Event(begin=begin, end=end) == e