Skip to content

Commit 1977e7a

Browse files
Event begin and end validation (#275)
* test basedataset datetime mismatch * implement begin/end condition on event class * update event and test_event --------- Co-authored-by: Gautzilla <72027971+Gautzilla@users.noreply.github.com>
1 parent e1627eb commit 1977e7a

4 files changed

Lines changed: 189 additions & 7 deletions

File tree

src/osekit/core_api/base_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def begin(self, value: Timestamp) -> None:
103103
Begin can only be set to a posterior date from the original begin.
104104
105105
"""
106+
self.items = [item for item in self.items if item.end >= value]
106107
for item in self.items:
107108
item.begin = max(item.begin, value)
108109

@@ -118,6 +119,7 @@ def end(self) -> Timestamp:
118119
@end.setter
119120
def end(self, value: Timestamp) -> None:
120121
"""Return true if every item of this data object is empty."""
122+
self.items = [item for item in self.items if item.begin < value]
121123
for item in self.items:
122124
item.end = min(item.end, value)
123125

src/osekit/core_api/event.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import bisect
66
import copy
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
88
from typing import TYPE_CHECKING, TypeVar
99

1010
if TYPE_CHECKING:
@@ -15,15 +15,48 @@
1515
class Event:
1616
"""Events are bounded between begin an end attributes.
1717
18-
Classes that have a begin and an end should inherit from Event.
18+
Classes that have a beginning and an end should inherit from Event.
1919
"""
2020

21-
begin: Timestamp
22-
end: Timestamp
21+
_begin: Timestamp = field(init=False, repr=False, compare=True)
22+
_end: Timestamp = field(init=False, repr=False, compare=True)
23+
24+
def __init__(
25+
self,
26+
begin: Timestamp,
27+
end: Timestamp
28+
) -> None:
29+
"""Initialize an Event instance with a beginning and an end."""
30+
self.begin = begin
31+
self.end = end
32+
33+
@property
34+
def begin(self) -> Timestamp:
35+
"""Beginning of the event."""
36+
return self._begin
37+
38+
@begin.setter
39+
def begin(self, value: Timestamp) -> None:
40+
if hasattr(self, "_end") and value >= self._end:
41+
msg = f"Invalid Event: `end` ({self._end}) must be greater than `begin` ({value})." # noqa: E501
42+
raise ValueError(msg)
43+
self._begin = value
44+
45+
@property
46+
def end(self) -> Timestamp:
47+
"""End of the event."""
48+
return self._end
49+
50+
@end.setter
51+
def end(self, value: Timestamp) -> None:
52+
if hasattr(self, "_begin") and value <= self._begin:
53+
msg = f"Invalid Event: `end` ({value}) must be greater than `begin` ({self._begin})." # noqa: E501
54+
raise ValueError(msg)
55+
self._end = value
2356

2457
@property
2558
def duration(self) -> Timedelta:
26-
"""Return the total duration of the data in seconds."""
59+
"""Duration of the event."""
2760
return self.end - self.begin
2861

2962
def overlaps(self, other: type[Event] | Event) -> bool:
@@ -115,8 +148,8 @@ def remove_overlaps(cls, events: list[TEvent]) -> list[TEvent]:
115148
116149
""" # noqa: E501
117150
events = sorted(
118-
[copy.copy(event) for event in events],
119-
key=lambda event: (event.begin, event.end),
151+
events,
152+
key=lambda event: (event.begin, -1*event.duration),
120153
)
121154
concatenated_events = []
122155
for event in events:

tests/test_audio.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,8 @@ def test_audio_dataset_from_files(
11081108
"corrupted_audio_files",
11091109
"non_audio_files",
11101110
"error",
1111+
"begin",
1112+
"end",
11111113
),
11121114
[
11131115
pytest.param(
@@ -1119,6 +1121,8 @@ def test_audio_dataset_from_files(
11191121
FileNotFoundError,
11201122
match="No valid file found in ",
11211123
),
1124+
None,
1125+
None,
11221126
id="no_file",
11231127
),
11241128
pytest.param(
@@ -1139,6 +1143,8 @@ def test_audio_dataset_from_files(
11391143
FileNotFoundError,
11401144
match="No valid file found in ",
11411145
),
1146+
None,
1147+
None,
11421148
id="corrupted_audio_files",
11431149
),
11441150
pytest.param(
@@ -1166,6 +1172,8 @@ def test_audio_dataset_from_files(
11661172
],
11671173
[],
11681174
None,
1175+
None,
1176+
None,
11691177
id="mixed_audio_files",
11701178
),
11711179
pytest.param(
@@ -1189,6 +1197,8 @@ def test_audio_dataset_from_files(
11891197
+ ".csv",
11901198
],
11911199
None,
1200+
None,
1201+
None,
11921202
id="non_audio_files_are_not_logged",
11931203
),
11941204
pytest.param(
@@ -1214,6 +1224,8 @@ def test_audio_dataset_from_files(
12141224
FileNotFoundError,
12151225
match="No valid file found in ",
12161226
),
1227+
None,
1228+
None,
12171229
id="all_but_ok_audio",
12181230
),
12191231
pytest.param(
@@ -1246,15 +1258,42 @@ def test_audio_dataset_from_files(
12461258
+ ".csv",
12471259
],
12481260
None,
1261+
None,
1262+
None,
12491263
id="full_mix",
12501264
),
1265+
pytest.param(
1266+
{
1267+
"duration": 1,
1268+
"sample_rate": 48_000,
1269+
"nb_files": 3,
1270+
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
1271+
"series_type": "increase",
1272+
},
1273+
generate_sample_audio(
1274+
nb_files=1,
1275+
nb_samples=144_000,
1276+
series_type="increase",
1277+
),
1278+
[],
1279+
[],
1280+
pytest.raises(
1281+
ValueError,
1282+
match=r"`end` .* must be greater than `begin`",
1283+
),
1284+
pd.Timestamp("2024-01-01 12:01:00"),
1285+
pd.Timestamp("2024-01-01 12:00:00"),
1286+
id="datetime_mismatch",
1287+
),
12511288
],
12521289
indirect=["audio_files"],
12531290
)
12541291
def test_audio_dataset_from_folder_errors_warnings(
12551292
tmp_path: Path,
12561293
caplog: pytest.LogCaptureFixture,
12571294
audio_files: tuple[list[Path], pytest.fixtures.Subrequest],
1295+
begin: pd.Timestamp | None,
1296+
end: pd.Timestamp | None,
12581297
expected_audio_data: list[np.ndarray],
12591298
corrupted_audio_files: list[str],
12601299
non_audio_files: list[str],
@@ -1270,6 +1309,8 @@ def test_audio_dataset_from_folder_errors_warnings(
12701309
AudioDataset.from_folder(
12711310
tmp_path,
12721311
strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED,
1312+
begin=begin,
1313+
end=end,
12731314
)
12741315
== e
12751316
)

tests/test_event.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from contextlib import nullcontext
4+
35
import pytest
46
from pandas import Timestamp
57

@@ -326,3 +328,107 @@ def test_get_overlapping_events(
326328
)
327329

328330
assert len(overlap_result) == len(expected_result)
331+
332+
333+
@pytest.mark.parametrize(
334+
("event", "updated_begin", "updated_end", "expected"),
335+
[
336+
pytest.param(
337+
Event(
338+
begin=Timestamp("2024-01-01 00:00:00"),
339+
end=Timestamp("2024-01-02 00:00:00"),
340+
),
341+
Timestamp("2024-01-01 12:00:00"),
342+
None,
343+
nullcontext(
344+
Event(
345+
begin=Timestamp("2024-01-01 12:00:00"),
346+
end=Timestamp("2024-01-02 00:00:00"),
347+
)
348+
),
349+
id="valid_begin",
350+
),
351+
pytest.param(
352+
Event(
353+
begin=Timestamp("2024-01-01 00:00:00"),
354+
end=Timestamp("2024-01-02 00:00:00"),
355+
),
356+
None,
357+
Timestamp("2024-01-02 12:00:00"),
358+
nullcontext(
359+
Event(
360+
begin=Timestamp("2024-01-01 00:00:00"),
361+
end=Timestamp("2024-01-02 12:00:00"),
362+
)
363+
),
364+
id="valid_end",
365+
),
366+
pytest.param(
367+
Event(
368+
begin=Timestamp("2024-01-01 00:00:00"),
369+
end=Timestamp("2024-01-02 00:00:00"),
370+
),
371+
Timestamp("2024-01-03 00:00:00"),
372+
None,
373+
pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"),
374+
id="invalid_begin_after_end",
375+
),
376+
pytest.param(
377+
Event(
378+
begin=Timestamp("2024-01-01 00:00:00"),
379+
end=Timestamp("2024-01-02 00:00:00"),
380+
),
381+
None,
382+
Timestamp("2023-12-31 23:59:59"),
383+
pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"),
384+
id="invalid_end_before_begin",
385+
),
386+
pytest.param(
387+
Event(
388+
begin=Timestamp("2024-01-01 00:00:00"),
389+
end=Timestamp("2024-01-01 01:00:00"),
390+
),
391+
Timestamp("2024-01-01 01:00:00"),
392+
None,
393+
pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"),
394+
id="begin_equals_end",
395+
),
396+
],
397+
)
398+
def test_event_begin_end_updates(
399+
event: Event,
400+
updated_begin: Timestamp | None,
401+
updated_end: Timestamp | None,
402+
expected: Event,
403+
) -> None:
404+
def update_event(
405+
cool_event: Event, begin: Timestamp | None, end: Timestamp | None
406+
) -> Event:
407+
if begin:
408+
cool_event.begin = begin
409+
if end:
410+
cool_event.end = end
411+
return cool_event
412+
413+
with expected as e:
414+
assert update_event(event, updated_begin, updated_end) == e
415+
416+
417+
@pytest.mark.parametrize(
418+
("begin", "end"),
419+
[
420+
pytest.param(
421+
Timestamp("2024-01-02 00:00:00"),
422+
Timestamp("2024-01-01 00:00:00"),
423+
id="begin_after_end",
424+
),
425+
pytest.param(
426+
Timestamp("2024-01-01 00:00:00"),
427+
Timestamp("2024-01-01 00:00:00"),
428+
id="begin_equals_end",
429+
),
430+
],
431+
)
432+
def test_event_errors(begin: Timestamp, end: Timestamp) -> None:
433+
with pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*") as e:
434+
assert Event(begin=begin, end=end) == e

0 commit comments

Comments
 (0)