|
1 | 1 | """Synchronization primitives.""" |
2 | 2 |
|
3 | 3 | __all__ = ('Lock', 'Event', 'Condition', 'Semaphore', |
4 | | - 'BoundedSemaphore', 'Barrier') |
| 4 | + 'BoundedSemaphore', 'Barrier', |
| 5 | + 'CapacityLimiter', 'CapacityLimiterStatistics') |
5 | 6 |
|
6 | 7 | import collections |
| 8 | +import dataclasses |
7 | 9 | import enum |
| 10 | +import math |
8 | 11 |
|
9 | 12 | from . import exceptions |
10 | 13 | from . import mixins |
@@ -615,3 +618,181 @@ def n_waiting(self): |
615 | 618 | def broken(self): |
616 | 619 | """Return True if the barrier is in a broken state.""" |
617 | 620 | return self._state is _BarrierState.BROKEN |
| 621 | + |
| 622 | + |
| 623 | +@dataclasses.dataclass(frozen=True) |
| 624 | +class CapacityLimiterStatistics: |
| 625 | + """Statistics for a CapacityLimiter.""" |
| 626 | + borrowed_tokens: int |
| 627 | + total_tokens: int | float |
| 628 | + borrowers: tuple[object, ...] |
| 629 | + tasks_waiting: int |
| 630 | + |
| 631 | + |
| 632 | +class CapacityLimiter(_ContextManagerMixin, mixins._LoopBoundMixin): |
| 633 | + """A capacity limiter that tracks borrowers and supports dynamic capacity. |
| 634 | +
|
| 635 | + Unlike a Semaphore, a CapacityLimiter: |
| 636 | + - Tracks which tasks hold tokens, preventing the same task from |
| 637 | + acquiring twice (which would deadlock a semaphore). |
| 638 | + - Allows dynamic adjustment of total_tokens at runtime. |
| 639 | + - Supports acquiring/releasing on behalf of arbitrary objects. |
| 640 | +
|
| 641 | + Usage:: |
| 642 | +
|
| 643 | + limiter = CapacityLimiter(10) |
| 644 | +
|
| 645 | + async with limiter: |
| 646 | + # At most 10 tasks can be here concurrently |
| 647 | + ... |
| 648 | +
|
| 649 | + """ |
| 650 | + |
| 651 | + def __init__(self, total_tokens: int | float): |
| 652 | + self._validate_tokens(total_tokens) |
| 653 | + self._total_tokens: int | float = total_tokens |
| 654 | + self._borrowers: set[object] = set() |
| 655 | + self._waiters: collections.OrderedDict[object, object] = ( |
| 656 | + collections.OrderedDict() |
| 657 | + ) |
| 658 | + |
| 659 | + def __repr__(self): |
| 660 | + res = super().__repr__() |
| 661 | + extra = (f'borrowed:{self.borrowed_tokens}, ' |
| 662 | + f'total:{self._total_tokens}') |
| 663 | + if self._waiters: |
| 664 | + extra = f'{extra}, waiters:{len(self._waiters)}' |
| 665 | + return f'<{res[1:-1]} [{extra}]>' |
| 666 | + |
| 667 | + @staticmethod |
| 668 | + def _validate_tokens(total_tokens): |
| 669 | + if not isinstance(total_tokens, (int, float)): |
| 670 | + raise TypeError("'total_tokens' must be an int or float") |
| 671 | + if isinstance(total_tokens, float) and total_tokens != math.inf: |
| 672 | + raise ValueError( |
| 673 | + "'total_tokens' must be an integer or math.inf" |
| 674 | + ) |
| 675 | + if total_tokens < 0: |
| 676 | + raise ValueError("'total_tokens' must be >= 0") |
| 677 | + |
| 678 | + @property |
| 679 | + def total_tokens(self) -> int | float: |
| 680 | + """The total number of tokens available (read-write).""" |
| 681 | + return self._total_tokens |
| 682 | + |
| 683 | + @total_tokens.setter |
| 684 | + def total_tokens(self, value: int | float): |
| 685 | + self._validate_tokens(value) |
| 686 | + self._total_tokens = value |
| 687 | + self._notify_waiters() |
| 688 | + |
| 689 | + @property |
| 690 | + def borrowed_tokens(self) -> int: |
| 691 | + """The number of tokens currently borrowed.""" |
| 692 | + return len(self._borrowers) |
| 693 | + |
| 694 | + @property |
| 695 | + def available_tokens(self) -> int | float: |
| 696 | + """The number of tokens currently available.""" |
| 697 | + return self._total_tokens - len(self._borrowers) |
| 698 | + |
| 699 | + def acquire_nowait(self) -> None: |
| 700 | + """Acquire a token on behalf of the current task without blocking. |
| 701 | +
|
| 702 | + Raises WouldBlock if a token is not immediately available. |
| 703 | + Raises RuntimeError if the current task already holds a token. |
| 704 | + """ |
| 705 | + from . import tasks |
| 706 | + self.acquire_on_behalf_of_nowait(tasks.current_task()) |
| 707 | + |
| 708 | + async def acquire(self) -> None: |
| 709 | + """Acquire a token on behalf of the current task. |
| 710 | +
|
| 711 | + Blocks until a token is available. |
| 712 | + Raises RuntimeError if the current task already holds a token. |
| 713 | + """ |
| 714 | + from . import tasks |
| 715 | + await self.acquire_on_behalf_of(tasks.current_task()) |
| 716 | + |
| 717 | + def acquire_on_behalf_of_nowait(self, borrower) -> None: |
| 718 | + """Acquire a token on behalf of the given borrower without blocking. |
| 719 | +
|
| 720 | + Raises WouldBlock if a token is not immediately available. |
| 721 | + Raises RuntimeError if the borrower already holds a token. |
| 722 | + """ |
| 723 | + if borrower in self._borrowers: |
| 724 | + raise RuntimeError( |
| 725 | + "this borrower is already holding one of this " |
| 726 | + "CapacityLimiter's tokens" |
| 727 | + ) |
| 728 | + if self._waiters or len(self._borrowers) >= self._total_tokens: |
| 729 | + raise exceptions.WouldBlock |
| 730 | + self._borrowers.add(borrower) |
| 731 | + |
| 732 | + async def acquire_on_behalf_of(self, borrower) -> None: |
| 733 | + """Acquire a token on behalf of the given borrower. |
| 734 | +
|
| 735 | + Blocks until a token is available. |
| 736 | + Raises RuntimeError if the borrower already holds a token. |
| 737 | + """ |
| 738 | + try: |
| 739 | + self.acquire_on_behalf_of_nowait(borrower) |
| 740 | + except exceptions.WouldBlock: |
| 741 | + pass |
| 742 | + else: |
| 743 | + return |
| 744 | + |
| 745 | + fut = self._get_loop().create_future() |
| 746 | + self._waiters[borrower] = fut |
| 747 | + try: |
| 748 | + await fut |
| 749 | + except exceptions.CancelledError: |
| 750 | + self._waiters.pop(borrower, None) |
| 751 | + # If the future was already resolved before we got cancelled, |
| 752 | + # we already hold the token — release it and wake the next waiter. |
| 753 | + if fut.done() and not fut.cancelled(): |
| 754 | + self._borrowers.discard(borrower) |
| 755 | + self._notify_waiters() |
| 756 | + raise |
| 757 | + else: |
| 758 | + # Future completed successfully; borrower was added by |
| 759 | + # _notify_waiters, nothing more to do. |
| 760 | + pass |
| 761 | + |
| 762 | + def release(self) -> None: |
| 763 | + """Release a token on behalf of the current task. |
| 764 | +
|
| 765 | + Raises RuntimeError if the current task does not hold a token. |
| 766 | + """ |
| 767 | + from . import tasks |
| 768 | + self.release_on_behalf_of(tasks.current_task()) |
| 769 | + |
| 770 | + def release_on_behalf_of(self, borrower) -> None: |
| 771 | + """Release a token on behalf of the given borrower. |
| 772 | +
|
| 773 | + Raises RuntimeError if the borrower does not hold a token. |
| 774 | + """ |
| 775 | + if borrower not in self._borrowers: |
| 776 | + raise RuntimeError( |
| 777 | + "this borrower is not holding any of this " |
| 778 | + "CapacityLimiter's tokens" |
| 779 | + ) |
| 780 | + self._borrowers.discard(borrower) |
| 781 | + self._notify_waiters() |
| 782 | + |
| 783 | + def _notify_waiters(self): |
| 784 | + """Wake up waiters while capacity is available.""" |
| 785 | + while self._waiters and len(self._borrowers) < self._total_tokens: |
| 786 | + borrower, fut = self._waiters.popitem(last=False) |
| 787 | + if not fut.done(): |
| 788 | + self._borrowers.add(borrower) |
| 789 | + fut.set_result(None) |
| 790 | + |
| 791 | + def statistics(self) -> CapacityLimiterStatistics: |
| 792 | + """Return statistics about the current state of the limiter.""" |
| 793 | + return CapacityLimiterStatistics( |
| 794 | + borrowed_tokens=len(self._borrowers), |
| 795 | + total_tokens=self._total_tokens, |
| 796 | + borrowers=tuple(self._borrowers), |
| 797 | + tasks_waiting=len(self._waiters), |
| 798 | + ) |
0 commit comments