|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from contextlib import nullcontext |
| 4 | + |
3 | 5 | import pytest |
4 | 6 | from pandas import Timestamp |
5 | 7 |
|
@@ -326,3 +328,107 @@ def test_get_overlapping_events( |
326 | 328 | ) |
327 | 329 |
|
328 | 330 | assert len(overlap_result) == len(expected_result) |
| 331 | + |
| 332 | + |
| 333 | +@pytest.mark.parametrize( |
| 334 | + ("event", "updated_begin", "updated_end", "expected"), |
| 335 | + [ |
| 336 | + pytest.param( |
| 337 | + Event( |
| 338 | + begin=Timestamp("2024-01-01 00:00:00"), |
| 339 | + end=Timestamp("2024-01-02 00:00:00"), |
| 340 | + ), |
| 341 | + Timestamp("2024-01-01 12:00:00"), |
| 342 | + None, |
| 343 | + nullcontext( |
| 344 | + Event( |
| 345 | + begin=Timestamp("2024-01-01 12:00:00"), |
| 346 | + end=Timestamp("2024-01-02 00:00:00"), |
| 347 | + ) |
| 348 | + ), |
| 349 | + id="valid_begin", |
| 350 | + ), |
| 351 | + pytest.param( |
| 352 | + Event( |
| 353 | + begin=Timestamp("2024-01-01 00:00:00"), |
| 354 | + end=Timestamp("2024-01-02 00:00:00"), |
| 355 | + ), |
| 356 | + None, |
| 357 | + Timestamp("2024-01-02 12:00:00"), |
| 358 | + nullcontext( |
| 359 | + Event( |
| 360 | + begin=Timestamp("2024-01-01 00:00:00"), |
| 361 | + end=Timestamp("2024-01-02 12:00:00"), |
| 362 | + ) |
| 363 | + ), |
| 364 | + id="valid_end", |
| 365 | + ), |
| 366 | + pytest.param( |
| 367 | + Event( |
| 368 | + begin=Timestamp("2024-01-01 00:00:00"), |
| 369 | + end=Timestamp("2024-01-02 00:00:00"), |
| 370 | + ), |
| 371 | + Timestamp("2024-01-03 00:00:00"), |
| 372 | + None, |
| 373 | + pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"), |
| 374 | + id="invalid_begin_after_end", |
| 375 | + ), |
| 376 | + pytest.param( |
| 377 | + Event( |
| 378 | + begin=Timestamp("2024-01-01 00:00:00"), |
| 379 | + end=Timestamp("2024-01-02 00:00:00"), |
| 380 | + ), |
| 381 | + None, |
| 382 | + Timestamp("2023-12-31 23:59:59"), |
| 383 | + pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"), |
| 384 | + id="invalid_end_before_begin", |
| 385 | + ), |
| 386 | + pytest.param( |
| 387 | + Event( |
| 388 | + begin=Timestamp("2024-01-01 00:00:00"), |
| 389 | + end=Timestamp("2024-01-01 01:00:00"), |
| 390 | + ), |
| 391 | + Timestamp("2024-01-01 01:00:00"), |
| 392 | + None, |
| 393 | + pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*"), |
| 394 | + id="begin_equals_end", |
| 395 | + ), |
| 396 | + ], |
| 397 | +) |
| 398 | +def test_event_begin_end_updates( |
| 399 | + event: Event, |
| 400 | + updated_begin: Timestamp | None, |
| 401 | + updated_end: Timestamp | None, |
| 402 | + expected: Event, |
| 403 | +) -> None: |
| 404 | + def update_event( |
| 405 | + cool_event: Event, begin: Timestamp | None, end: Timestamp | None |
| 406 | + ) -> Event: |
| 407 | + if begin: |
| 408 | + cool_event.begin = begin |
| 409 | + if end: |
| 410 | + cool_event.end = end |
| 411 | + return cool_event |
| 412 | + |
| 413 | + with expected as e: |
| 414 | + assert update_event(event, updated_begin, updated_end) == e |
| 415 | + |
| 416 | + |
| 417 | +@pytest.mark.parametrize( |
| 418 | + ("begin", "end"), |
| 419 | + [ |
| 420 | + pytest.param( |
| 421 | + Timestamp("2024-01-02 00:00:00"), |
| 422 | + Timestamp("2024-01-01 00:00:00"), |
| 423 | + id="begin_after_end", |
| 424 | + ), |
| 425 | + pytest.param( |
| 426 | + Timestamp("2024-01-01 00:00:00"), |
| 427 | + Timestamp("2024-01-01 00:00:00"), |
| 428 | + id="begin_equals_end", |
| 429 | + ), |
| 430 | + ], |
| 431 | +) |
| 432 | +def test_event_errors(begin: Timestamp, end: Timestamp) -> None: |
| 433 | + with pytest.raises(ValueError, match="`end`.*must be greater than `begin`.*") as e: |
| 434 | + assert Event(begin=begin, end=end) == e |
0 commit comments