Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/osekit/core_api/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def from_files( # noqa: PLR0913
begin = min(file.begin for file in files)
if not end:
end = max(file.end for file in files)
if begin > end:
msg = (f"`begin` ({begin}) must be smaller than `end`({end})")
raise ValueError(msg)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be `if begin >= end`, since the error message says that `begin` must be smaller than `end`? (Besides, a 0-second-long dataset wouldn't make much sense.)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ron-swanson-done

if data_duration:
data_base = (
cls._get_base_data_from_files_timedelta_total(
Expand Down
41 changes: 41 additions & 0 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,8 @@ def test_audio_dataset_from_files(
"corrupted_audio_files",
"non_audio_files",
"error",
"begin",
"end",
),
[
pytest.param(
Expand All @@ -1056,6 +1058,8 @@ def test_audio_dataset_from_files(
FileNotFoundError,
match="No valid file found in ",
),
None,
None,
id="no_file",
),
pytest.param(
Expand All @@ -1076,6 +1080,8 @@ def test_audio_dataset_from_files(
FileNotFoundError,
match="No valid file found in ",
),
None,
None,
id="corrupted_audio_files",
),
pytest.param(
Expand Down Expand Up @@ -1103,6 +1109,8 @@ def test_audio_dataset_from_files(
],
[],
None,
None,
None,
id="mixed_audio_files",
),
pytest.param(
Expand All @@ -1126,6 +1134,8 @@ def test_audio_dataset_from_files(
+ ".csv",
],
None,
None,
None,
id="non_audio_files_are_not_logged",
),
pytest.param(
Expand All @@ -1151,6 +1161,8 @@ def test_audio_dataset_from_files(
FileNotFoundError,
match="No valid file found in ",
),
None,
None,
id="all_but_ok_audio",
),
pytest.param(
Expand Down Expand Up @@ -1183,15 +1195,42 @@ def test_audio_dataset_from_files(
+ ".csv",
],
None,
None,
None,
id="full_mix",
),
pytest.param(
{
"duration": 1,
"sample_rate": 48_000,
"nb_files": 3,
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
"series_type": "increase",
},
generate_sample_audio(
nb_files=1,
nb_samples=144_000,
series_type="increase",
),
[],
[],
pytest.raises(
ValueError,
match=r"`begin` .* must be smaller than `end`",
),
pd.Timestamp("2024-01-01 12:01:00"),
pd.Timestamp("2024-01-01 12:00:00"),
id="datetime_mismatch",
),
],
indirect=["audio_files"],
)
def test_audio_dataset_from_folder_errors_warnings(
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
audio_files: tuple[list[Path], pytest.fixtures.Subrequest],
begin: pd.Timestamp | None,
end: pd.Timestamp | None,
expected_audio_data: list[np.ndarray],
corrupted_audio_files: list[str],
non_audio_files: list[str],
Expand All @@ -1207,6 +1246,8 @@ def test_audio_dataset_from_folder_errors_warnings(
AudioDataset.from_folder(
tmp_path,
strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED,
begin=begin,
end=end,
)
== e
)
Expand Down