Refactor prefetcher code #818
base: main
@@ -78,6 +78,24 @@ def average(self) -> int:
             return 1024 * 1024  # 1MB
         return self._sum // count

+    @property
+    def is_variable(self) -> bool:
+        """Determines if the history contains distinct chunk sizes."""
+        count = len(self._history)
+        if count < 2:
+            return False
+
+        first_val = self._history[0]
+        return any(val != first_val for val in self._history)
|
Comment on lines
+88
to
+89
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
? Whether this is faster probably depends on where the first non-equal value is and how long the list can be. |
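To probe that question, here is a small, self-contained timing sketch; the history length and contents are assumptions picked to show the worst case for the early-exit scan, not values from this PR:

```python
import timeit

# Worst case for the early-exit scan: the only mismatch is at the very end.
history = [1024] * 100_000 + [2048]

def scan(h):     # the approach in this PR: stops at the first mismatch
    return any(v != h[0] for v in h)

def via_set(h):  # alternative: always walks the whole list, but in C
    return len(set(h)) > 1

print(timeit.timeit(lambda: scan(history), number=100))
print(timeit.timeit(lambda: via_set(history), number=100))
```

With an early mismatch the generator version wins; on long, mostly-constant histories the set version tends to win despite doing more work, because its loop runs in C.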
+    @property
+    def last_value(self) -> int:
+        """Returns the most recent entry in the history."""
+        if not self._history:
+            raise RuntimeError("No entry found in history")
+
+        return self._history[-1]
+
     def clear(self):
         """Clears the history and resets the sum to zero."""
         logger.debug("Clearing RunningAverageTracker history.")
@@ -101,17 +119,34 @@ class PrefetchProducer:
     # to maximum of 2 * io_size and 128MB
     MIN_PREFETCH_SIZE = 128 * 1024 * 1024

+    # The prefetching starts on the third read.
+    MIN_STREAKS_FOR_PREFETCHING = 3
+
+    # Threshold for disabling proactive prefetching on large, variable reads.
+    #
+    # If the average read size exceeds this value and patterns are variable,
+    # prefetching shifts from an I/O bottleneck to a CPU bottleneck. When a user
Member: Suggested change
+    # requests random massive sizes (e.g., jumping between 64MB and INF), the
+    # producer still fetches chunks based on the rolling average. The consumer
+    # then has to pick up multiple chunks and stitch them together to match the
+    # exact requested size.
+    #
+    # For small average read sizes, this byte assembly is fast and the bottleneck
+    # remains the network I/O. However, for massive reads (>= 64MB), the extra
+    # step of copying and assembling huge byte strings in memory severely slows
+    # down the operation.
+    VARIABLE_IO_THRESHOLD = 64 * 1024 * 1024
Member: Does the best value of this depend on the network bandwidth? I bet on slow connections, we always prefer any amount of prefetching and copy time is irrelevant.
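For intuition on the copy cost the comment block describes (and the reviewer's bandwidth point), a rough, self-contained timing sketch; the 8MB piece size is an assumption, not the producer's actual chunking:

```python
import timeit

chunk = b"x" * (8 * 1024 * 1024)  # pretend the producer fetched 8MB pieces
chunks = [chunk] * 8              # a 64MB read stitched from 8 pieces

# The consumer-side assembly: one full copy of 64MB into a new bytes object.
seconds = timeit.timeit(lambda: b"".join(chunks), number=10) / 10
print(f"join cost: {seconds:.4f}s per 64MB")
```

On a fast link that copy can rival the transfer itself; on a slow connection it is likely negligible next to network I/O, which supports the point that the best threshold may depend on bandwidth.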
     def __init__(
         self,
         fetcher,
         size: int,
         concurrency: int,
         queue: asyncio.Queue,
         wakeup_event: asyncio.Event,
-        get_user_offset,
-        get_io_size,
-        get_sequential_streak,
-        on_error,
+        consumer: "PrefetchConsumer",
+        tracker: RunningAverageTracker,
+        orchestrator: "BackgroundPrefetcher",
         user_max_prefetch_size=None,
     ):
         """Initializes the background producer.
@@ -122,10 +157,9 @@ def __init__(
             concurrency (int): Maximum number of concurrent fetch tasks.
             queue (asyncio.Queue): The shared queue to push download tasks into.
             wakeup_event (asyncio.Event): Event used to wake the producer from an idle state.
-            get_user_offset (Callable): Function returning the user's current read offset.
-            get_io_size (Callable): Function returning the adaptive IO size.
-            get_sequential_streak (Callable): Function returning the current sequential read streak.
-            on_error (Callable): Callback triggered when a background error occurs.
+            consumer (PrefetchConsumer): The consumer reading the prefetched chunks.
+            tracker (RunningAverageTracker): Tracker for history of read sizes.
+            orchestrator (BackgroundPrefetcher): The parent object managing the operation.
             user_max_prefetch_size (int, optional): A hard limit for prefetch size overrides.
         """
         logger.debug(
@@ -140,10 +174,9 @@ def __init__(
         self.queue = queue
         self.wakeup_event = wakeup_event

-        self.get_user_offset = get_user_offset
-        self.get_io_size = get_io_size
-        self.get_sequential_streak = get_sequential_streak
-        self.on_error = on_error
+        self.consumer = consumer
+        self.tracker = tracker
+        self.orchestrator = orchestrator
Member: Reference cycle - prefer a weakref?
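A minimal sketch of the weakref suggestion, with stand-in class names; the real Producer/BackgroundPrefetcher pair would follow the same shape:

```python
import weakref

class Producer:
    def __init__(self, orchestrator):
        # A proxy lets the producer call back into the orchestrator
        # without keeping it alive, so no strong reference cycle forms.
        self._orchestrator = weakref.proxy(orchestrator)

    def fail(self, e: Exception):
        # Raises ReferenceError if the orchestrator was garbage collected.
        self._orchestrator._set_error(e)

class Orchestrator:
    def __init__(self):
        self._error = None
        self.producer = Producer(self)  # strong reference in one direction only

    def _set_error(self, e: Exception):
        self._error = e
```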
         self._user_max_prefetch_size = user_max_prefetch_size

         self.current_offset = 0
@@ -161,9 +194,9 @@ def max_prefetch_size(self) -> int:
         if self._user_max_prefetch_size is not None:
             return min(
                 self._user_max_prefetch_size,
-                max(2 * self.get_io_size(), self.MIN_PREFETCH_SIZE),
+                max(2 * self.tracker.average, self.MIN_PREFETCH_SIZE),
             )
-        return max(2 * self.get_io_size(), self.MIN_PREFETCH_SIZE)
+        return max(2 * self.tracker.average, self.MIN_PREFETCH_SIZE)

     def start(self):
         """Starts the background producer loop.
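A worked example of the cap computed by max_prefetch_size above, with assumed averages:

```python
MIN_PREFETCH_SIZE = 128 * 1024 * 1024  # mirrors the class constant
MB = 1024 * 1024

def max_prefetch_size(average: int, user_max=None) -> int:
    base = max(2 * average, MIN_PREFETCH_SIZE)
    return min(user_max, base) if user_max is not None else base

print(max_prefetch_size(1 * MB) // MB)             # 128: the floor dominates
print(max_prefetch_size(100 * MB) // MB)           # 200: 2x average dominates
print(max_prefetch_size(100 * MB, 64 * MB) // MB)  # 64: the user cap dominates
```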
@@ -246,23 +279,60 @@ async def _loop(self):
                 if self.is_stopped:
                     break

-                io_size = self.get_io_size()
-                streak = self.get_sequential_streak()
-                prefetch_size = min((streak + 1) * io_size, self.max_prefetch_size)
+                avg_io_size = self.tracker.average
+                streak = self.consumer.sequential_streak
+                is_variable = self.tracker.is_variable
+                last_read_size = self.tracker.last_value
+
+                exceeds_user_max = (
+                    self._user_max_prefetch_size is not None
+                    and avg_io_size > self._user_max_prefetch_size
+                )
+
+                # Disable prefetching ahead if highly variable AND average > 64MB, or if it exceeds user max
Member: If anything other than exactly constant, no?
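To the reviewer's point: is_variable flips on any deviation, so a single read that differs by one byte disables prefetch-ahead once the average is past the threshold. A small illustration with made-up sizes:

```python
MB = 1024 * 1024
VARIABLE_IO_THRESHOLD = 64 * MB  # mirrors the class constant

history = [128 * MB] * 9 + [128 * MB + 1]  # one read off by a single byte
avg = sum(history) // len(history)
is_variable = any(v != history[0] for v in history)  # True

# The gate fires despite a near-perfectly steady pattern.
print(is_variable and avg > VARIABLE_IO_THRESHOLD)  # True
```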
+                if (
+                    is_variable and avg_io_size > PrefetchProducer.VARIABLE_IO_THRESHOLD
+                ) or exceeds_user_max:
+                    logger.debug(
+                        "Large IO detected (variable > 64MB or > user max). Disabling background prefetching."
+                    )
+                    prefetch_multiplier = 1
+                elif streak < self.MIN_STREAKS_FOR_PREFETCHING:
+                    prefetch_multiplier = 1
+                else:
+                    prefetch_multiplier = streak - self.MIN_STREAKS_FOR_PREFETCHING + 1
+
+                if self.queue.empty() or prefetch_multiplier == 1:
+                    io_size = last_read_size
+                else:
+                    io_size = avg_io_size
+
+                prefetch_size = min(
+                    prefetch_multiplier * io_size, self.max_prefetch_size
+                )
+                if self.consumer.offset + prefetch_size < self.consumer.target_offset:
+                    prefetch_size = self.consumer.target_offset - self.consumer.offset
+
+                if is_variable:
+                    effective_prefetch_size = prefetch_size
+                else:
+                    effective_prefetch_size = (prefetch_size // io_size) * io_size
+                    if effective_prefetch_size == 0:
+                        effective_prefetch_size = prefetch_size
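To make the alignment step above concrete, a small worked example with assumed sizes:

```python
MB = 1024 * 1024
io_size = 4 * MB
prefetch_size = 18 * MB

effective = (prefetch_size // io_size) * io_size  # rounds down to 16MB
if effective == 0:  # prefetch_size smaller than one io_size
    effective = prefetch_size
print(effective // MB)  # 16
```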
                 logger.debug(
                     "Producer awake. Current offset: %d, User offset: %d, Prefetch size: %d",
                     self.current_offset,
-                    self.get_user_offset(),
+                    self.consumer.offset,
                     prefetch_size,
                 )

                 while (
                     not self.is_stopped
-                    and (self.current_offset - self.get_user_offset()) < prefetch_size
+                    and (self.current_offset - self.consumer.offset) < prefetch_size
                     and self.current_offset < self.size
                 ):
-                    user_offset = self.get_user_offset()
+                    user_offset = self.consumer.offset
                     space_remaining = self.size - self.current_offset
                     prefetch_space_available = prefetch_size - (
                         self.current_offset - user_offset
@@ -278,14 +348,22 @@ async def _loop(self):
                     else:
                         actual_size = min(io_size, space_remaining)

-                    if streak < 2:
+                    if prefetch_space_available < actual_size:
+                        if is_variable or prefetch_space_available == prefetch_size:
+                            actual_size = prefetch_space_available
+                        else:
+                            break
+
+                    if streak < PrefetchProducer.MIN_STREAKS_FOR_PREFETCHING:
                         sfactor = self.concurrency
                     else:
                         sfactor = min(
                             self.concurrency,
                             max(
                                 1,
-                                actual_size * self.concurrency // prefetch_size,
+                                actual_size
+                                * self.concurrency
+                                // effective_prefetch_size,
                             ),
                         )
@@ -317,7 +395,7 @@ async def _loop(self):
                     exc_info=True,
                 )
                 self.is_stopped = True
-                self.on_error(e)
+                self.orchestrator._set_error(e)
                 await self.queue.put(e)
@@ -332,24 +410,25 @@ def __init__(
         self,
         queue: asyncio.Queue,
         wakeup_event: asyncio.Event,
-        is_producer_stopped,
-        on_error,
+        tracker: RunningAverageTracker,
+        orchestrator: "BackgroundPrefetcher",
     ):
         """Initializes the consumer.

         Args:
             queue (asyncio.Queue): The shared queue containing fetch tasks.
             wakeup_event (asyncio.Event): Event used to wake the producer when more data is needed.
-            is_producer_stopped (Callable): Function returning whether the producer has been halted.
-            on_error (Callable): Callback triggered when a fetch error is encountered.
+            tracker (RunningAverageTracker): Tracker for history of read sizes.
+            orchestrator (BackgroundPrefetcher): The parent object managing the operation.
         """
         logger.debug("Initializing PrefetchConsumer.")
         self.queue = queue
         self.wakeup_event = wakeup_event
-        self.is_producer_stopped = is_producer_stopped
-        self.on_error = on_error
+        self.tracker = tracker
+        self.orchestrator = orchestrator
         self.sequential_streak = 0
         self.offset = 0
+        self.target_offset = 0
         self._current_block = b""
         self._current_block_idx = 0
@@ -364,6 +443,7 @@ def seek(self, new_offset: int):
             new_offset,
         )
         self.offset = new_offset
+        self.target_offset = new_offset
         self.sequential_streak = 0
         self._current_block = b""
         self._current_block_idx = 0
@@ -384,12 +464,18 @@ async def _advance(self, size: int, save_data: bool) -> list[bytes]:

         chunks = []
         processed = 0
+        self.target_offset = self.offset + size

         while processed < size:
             available = len(self._current_block) - self._current_block_idx
+            trigger_wakeup = False

             if not available:
-                if self.is_producer_stopped() and self.queue.empty():
+                is_producer_stopped = (
+                    not hasattr(self.orchestrator, "producer")
+                    or self.orchestrator.producer.is_stopped
+                )
+                if is_producer_stopped and self.queue.empty():
                     logger.debug("Consumer reached EOF.")
                     break

Collaborator: nit: how about setting the producer to None, then we do …

Member: (or set it to None as a class attribute)
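A sketch of the reviewers' suggestion; it assumes the names used in this PR, with producer declared as a None class attribute so the hasattr() probe disappears:

```python
class BackgroundPrefetcher:
    producer = None  # overwritten in __init__ once the producer is built

# The consumer's EOF check then simplifies to:
#     is_producer_stopped = (
#         self.orchestrator.producer is None
#         or self.orchestrator.producer.is_stopped
#     )
```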
@@ -401,15 +487,38 @@ async def _advance(self, size: int, save_data: bool) -> list[bytes]:

             if isinstance(task, Exception):
                 logger.error("Consumer retrieved an exception: %s", task)
-                self.on_error(task)
+                self.orchestrator._set_error(task)

Member: _set_error doesn't need the leading _ really - the only calling context is from this class, and it is meant to be called.

                 raise task

             try:
                 block = await task

                 self.sequential_streak += 1
-                if self.sequential_streak >= 2:
-                    self.wakeup_event.set()
+                if (
+                    self.sequential_streak
+                    >= PrefetchProducer.MIN_STREAKS_FOR_PREFETCHING
+                ):
+                    is_variable = self.tracker.is_variable
+                    avg_io_size = self.tracker.average
+
+                    exceeds_user_max = (
+                        self.orchestrator.max_prefetch_size is not None
+                        and avg_io_size > self.orchestrator.max_prefetch_size
+                    )
+                    is_massive_variable = (
+                        is_variable
+                        and avg_io_size > PrefetchProducer.VARIABLE_IO_THRESHOLD
+                    )
+
+                    # Suppress proactive wakeups to prevent large CPU assembly
+                    # on erratic large reads or exceeding max
+                    if not (is_massive_variable or exceeds_user_max):
+                        trigger_wakeup = True
+                    else:
+                        logger.debug(
+                            "Suppressing proactive producer wakeup due to massive variable"
+                            " workload or exceeding user max prefetch."
+                        )

                 self._current_block = block
                 self._current_block_idx = 0
@@ -418,7 +527,7 @@ async def _advance(self, size: int, save_data: bool) -> list[bytes]:
                 raise
             except Exception as e:
                 logger.error("Consumer caught an error: %s", e, exc_info=True)
-                self.on_error(e)
+                self.orchestrator._set_error(e)
                 raise e

         if not self._current_block:

@@ -440,6 +549,8 @@ async def _advance(self, size: int, save_data: bool) -> list[bytes]:
             self._current_block_idx += take
             processed += take
             self.offset += take
+            if trigger_wakeup:
+                self.wakeup_event.set()

         return chunks
@@ -504,6 +615,7 @@ def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None)
         )
         self.size = size
         self.concurrency = concurrency
+        self.max_prefetch_size = max_prefetch_size

         if max_prefetch_size is not None and max_prefetch_size <= 0:
             logger.error("Invalid max_prefetch_size provided: %s", max_prefetch_size)
@@ -523,8 +635,8 @@ def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None)
         self.consumer = PrefetchConsumer(
             queue=self.queue,
             wakeup_event=self.wakeup_event,
-            is_producer_stopped=self._is_producer_stopped,
-            on_error=self._set_error,
+            tracker=self.read_tracker,
+            orchestrator=self,
         )

         self.producer = PrefetchProducer(
@@ -533,10 +645,9 @@ def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None)
             concurrency=self.concurrency,
             queue=self.queue,
             wakeup_event=self.wakeup_event,
-            get_user_offset=lambda: self.consumer.offset,
-            get_io_size=self._get_adaptive_io_size,
-            get_sequential_streak=lambda: self.consumer.sequential_streak,
-            on_error=self._set_error,
+            consumer=self.consumer,
+            tracker=self.read_tracker,
+            orchestrator=self,
             user_max_prefetch_size=max_prefetch_size,
         )
@@ -554,12 +665,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         """Context manager exit point. Ensures the prefetcher is cleanly closed."""
         self.close()

-    def _get_adaptive_io_size(self) -> int:
-        return self.read_tracker.average
-
-    def _is_producer_stopped(self) -> bool:
-        return self.producer.is_stopped if hasattr(self, "producer") else True
-
     def _set_error(self, e: Exception):
         logger.error("Global error state set in BackgroundPrefetcher: %s", e)
         self._error = e
Removing superlatives - I think it's clearer with neutral language.