Skip to content

Commit 46845ca

Browse files
committed
implement begin/end condition on event class
1 parent f6ac7cf commit 46845ca

5 files changed

Lines changed: 138 additions & 9 deletions

File tree

src/osekit/core_api/base_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def begin(self, value: Timestamp) -> None:
103103
Begin can only be set to a posterior date from the original begin.
104104
105105
"""
106+
self.items = [item for item in self.items if item.end >= value]
106107
for item in self.items:
107108
item.begin = max(item.begin, value)
108109

@@ -118,6 +119,7 @@ def end(self) -> Timestamp:
118119
@end.setter
119120
def end(self, value: Timestamp) -> None:
120121
"""Return true if every item of this data object is empty."""
122+
self.items = [item for item in self.items if item.begin < value]
121123
for item in self.items:
122124
item.end = min(item.end, value)
123125

src/osekit/core_api/base_dataset.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,6 @@ def from_files( # noqa: PLR0913
299299
begin = min(file.begin for file in files)
300300
if not end:
301301
end = max(file.end for file in files)
302-
if begin >= end:
303-
msg = (f"`begin` ({begin}) must be smaller than `end`({end})")
304-
raise ValueError(msg)
305302
if data_duration:
306303
data_base = (
307304
cls._get_base_data_from_files_timedelta_total(

src/osekit/core_api/event.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import bisect
66
import copy
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
88
from typing import TYPE_CHECKING, TypeVar
99

1010
if TYPE_CHECKING:
@@ -15,15 +15,51 @@
1515
class Event:
1616
"""Events are bounded between begin an end attributes.
1717
18-
Classes that have a begin and an end should inherit from Event.
18+
Classes that have a beginning and an end should inherit from Event.
1919
"""
2020

21-
begin: Timestamp
22-
end: Timestamp
21+
_begin: Timestamp = field(init=False, repr=False, compare=True)
22+
_end: Timestamp = field(init=False, repr=False, compare=True)
23+
24+
def __init__(
25+
self,
26+
begin: Timestamp,
27+
end: Timestamp
28+
) -> None:
29+
"""Initialize an Event instance with a beginning and an end."""
30+
if begin > end:
31+
msg = f"Invalid Event: `end` ({end}) must be greater than `begin` ({begin})." # noqa: E501
32+
raise ValueError(msg)
33+
self._begin = begin
34+
self._end = end
35+
36+
@property
37+
def begin(self) -> Timestamp:
38+
"""Beginning of the event."""
39+
return self._begin
40+
41+
@begin.setter
42+
def begin(self, value: Timestamp) -> None:
43+
if hasattr(self, "_end") and value > self._end:
44+
msg = f"Invalid Event: `end` ({self._end}) must be greater than `begin` ({value})." # noqa: E501
45+
raise ValueError(msg)
46+
self._begin = value
47+
48+
@property
49+
def end(self) -> Timestamp:
50+
"""End of the event."""
51+
return self._end
52+
53+
@end.setter
54+
def end(self, value: Timestamp) -> None:
55+
if hasattr(self, "_begin") and value < self._begin:
56+
msg = f"Invalid Event: `end` ({value}) must be greater than `begin` ({self._begin})." # noqa: E501
57+
raise ValueError(msg)
58+
self._end = value
2359

2460
@property
2561
def duration(self) -> Timedelta:
26-
"""Return the total duration of the data in seconds."""
62+
"""Duration of the event."""
2763
return self.end - self.begin
2864

2965
def overlaps(self, other: type[Event] | Event) -> bool:

tests/test_audio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ def test_audio_dataset_from_files(
12161216
[],
12171217
pytest.raises(
12181218
ValueError,
1219-
match=r"`begin` .* must be smaller than `end`",
1219+
match=r"`end` .* must be greater than `begin`",
12201220
),
12211221
pd.Timestamp("2024-01-01 12:01:00"),
12221222
pd.Timestamp("2024-01-01 12:00:00"),

tests/test_event.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,97 @@ def test_get_overlapping_events(
326326
)
327327

328328
assert len(overlap_result) == len(expected_result)
329+
330+
331+
@pytest.mark.parametrize(
332+
("initial", "updates", "expected_values", "error_at"),
333+
[
334+
pytest.param(
335+
[
336+
("begin", Timestamp("2024-01-01 00:00:00")),
337+
("end", Timestamp("2024-01-02 00:00:00")),
338+
],
339+
[
340+
("begin", Timestamp("2024-01-01 12:00:00")),
341+
("end", Timestamp("2024-01-02 12:00:00")),
342+
],
343+
(Timestamp("2024-01-01 12:00:00"), Timestamp("2024-01-02 12:00:00")),
344+
None,
345+
id="valid sequential updates",
346+
),
347+
pytest.param(
348+
[
349+
("begin", Timestamp("2024-01-01 00:00:00")),
350+
("end", Timestamp("2024-01-02 00:00:00")),
351+
],
352+
[("begin", Timestamp("2024-01-03 00:00:00"))],
353+
(Timestamp("2024-01-01 00:00:00"), Timestamp("2024-01-02 00:00:00")),
354+
0,
355+
id="invalid begin > end",
356+
),
357+
pytest.param(
358+
[
359+
("begin", Timestamp("2024-01-01 00:00:00")),
360+
("end", Timestamp("2024-01-02 00:00:00")),
361+
],
362+
[("end", Timestamp("2023-12-31 23:59:59"))],
363+
(Timestamp("2024-01-01 00:00:00"), Timestamp("2024-01-02 00:00:00")),
364+
0,
365+
id="invalid end < begin",
366+
),
367+
pytest.param(
368+
[
369+
("begin", Timestamp("2024-01-01 00:00:00")),
370+
("end", Timestamp("2024-01-02 00:00:00")),
371+
],
372+
[
373+
("begin", Timestamp("2024-01-01 12:00:00")),
374+
("end", Timestamp("2024-01-01 06:00:00")),
375+
("end", Timestamp("2024-01-02 12:00:00")),
376+
],
377+
(Timestamp("2024-01-01 12:00:00"), Timestamp("2024-01-02 12:00:00")),
378+
1,
379+
id="mixed valid and invalid updates",
380+
),
381+
pytest.param(
382+
[
383+
("begin", Timestamp("2024-01-01 00:00:00")),
384+
("end", Timestamp("2024-01-01 01:00:00")),
385+
],
386+
[("begin", Timestamp("2024-01-01 01:00:00"))],
387+
(Timestamp("2024-01-01 01:00:00"), Timestamp("2024-01-01 01:00:00")),
388+
None,
389+
id="begin equals end edge",
390+
),
391+
pytest.param(
392+
[
393+
("begin", Timestamp("2024-01-01 00:00:00")),
394+
("end", Timestamp("2024-01-01 01:00:00")),
395+
],
396+
[("end", Timestamp("2024-01-01 00:00:00"))],
397+
(Timestamp("2024-01-01 00:00:00"), Timestamp("2024-01-01 00:00:00")),
398+
None,
399+
id="end equals begin edge",
400+
),
401+
],
402+
)
403+
404+
def test_event_begin_end_updates(
405+
initial: list[tuple[str, Timestamp]],
406+
updates: list[tuple[str, Timestamp]],
407+
expected_values: tuple[Timestamp, Timestamp],
408+
error_at: int | None,
409+
) -> None:
410+
initial_dict = dict(initial)
411+
cool_event = Event(begin=initial_dict["begin"], end=initial_dict["end"])
412+
413+
for i, (attr, value) in enumerate(updates):
414+
if error_at is not None and i == error_at:
415+
with pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"):
416+
setattr(cool_event, attr, value)
417+
else:
418+
setattr(cool_event, attr, value)
419+
420+
assert cool_event.begin == expected_values[0]
421+
assert cool_event.end == expected_values[1]
422+

0 commit comments

Comments
 (0)