
Commit f0ddd03

update prefetcher logic
1 parent de3a928 commit f0ddd03

2 files changed: 116 additions & 48 deletions


gcsfs/prefetcher.py

Lines changed: 110 additions & 44 deletions
@@ -78,6 +78,24 @@ def average(self) -> int:
             return 1024 * 1024  # 1MB
         return self._sum // count

+    @property
+    def is_variable(self) -> bool:
+        """Determines if the history contains distinct chunk sizes."""
+        count = len(self._history)
+        if count < 2:
+            return False
+
+        first_val = self._history[0]
+        return any(val != first_val for val in self._history)
+
+    @property
+    def last_value(self) -> int:
+        """Returns the most recent entry in the history."""
+        if not self._history:
+            raise RuntimeError("No entry found in history")
+
+        return self._history[-1]
+
     def clear(self):
         """Clears the history and resets the sum to zero."""
         logger.debug("Clearing RunningAverageTracker history.")
@@ -101,17 +119,34 @@ class PrefetchProducer:
     # to maximum of 2 * io_size and 128MB
     MIN_PREFETCH_SIZE = 128 * 1024 * 1024

+    # Prefetching starts on the third read.
+    MIN_STREAKS_FOR_PREFETCHING = 3
+
+    # Threshold for disabling proactive prefetching on large, variable reads.
+    #
+    # If the average read size exceeds this value and patterns are variable,
+    # prefetching shifts from an I/O bottleneck to a CPU bottleneck. When a user
+    # requests random massive sizes (e.g., jumping between 100MB and INF), the
+    # producer still fetches chunks based on the rolling average. The consumer
+    # then has to pick up multiple chunks and stitch them together to match the
+    # exact requested size.
+    #
+    # For small average read sizes, this byte assembly is fast and the bottleneck
+    # remains the network I/O. However, for massive reads (>= 100MB), the extra
+    # step of copying and assembling huge byte strings in memory severely slows
+    # down the operation.
+    VARIABLE_IO_THRESHOLD = 100 * 1024 * 1024
+
     def __init__(
         self,
         fetcher,
         size: int,
         concurrency: int,
         queue: asyncio.Queue,
         wakeup_event: asyncio.Event,
-        get_user_offset,
-        get_io_size,
-        get_sequential_streak,
-        on_error,
+        consumer: "PrefetchConsumer",
+        tracker: RunningAverageTracker,
+        orchestrator: "BackgroundPrefetcher",
         user_max_prefetch_size=None,
     ):
         """Initializes the background producer.
@@ -122,10 +157,9 @@ def __init__(
             concurrency (int): Maximum number of concurrent fetch tasks.
             queue (asyncio.Queue): The shared queue to push download tasks into.
             wakeup_event (asyncio.Event): Event used to wake the producer from an idle state.
-            get_user_offset (Callable): Function returning the user's current read offset.
-            get_io_size (Callable): Function returning the adaptive IO size.
-            get_sequential_streak (Callable): Function returning the current sequential read streak.
-            on_error (Callable): Callback triggered when a background error occurs.
+            consumer (PrefetchConsumer): The consumer reading the prefetched chunks.
+            tracker (RunningAverageTracker): Tracker for history of read sizes.
+            orchestrator (BackgroundPrefetcher): The parent object managing the operation.
             user_max_prefetch_size (int, optional): A hard limit for prefetch size overrides.
         """
         logger.debug(
@@ -140,10 +174,9 @@ def __init__(
         self.queue = queue
         self.wakeup_event = wakeup_event

-        self.get_user_offset = get_user_offset
-        self.get_io_size = get_io_size
-        self.get_sequential_streak = get_sequential_streak
-        self.on_error = on_error
+        self.consumer = consumer
+        self.tracker = tracker
+        self.orchestrator = orchestrator
         self._user_max_prefetch_size = user_max_prefetch_size

         self.current_offset = 0
@@ -161,9 +194,9 @@ def max_prefetch_size(self) -> int:
         if self._user_max_prefetch_size is not None:
             return min(
                 self._user_max_prefetch_size,
-                max(2 * self.get_io_size(), self.MIN_PREFETCH_SIZE),
+                max(2 * self.tracker.average, self.MIN_PREFETCH_SIZE),
             )
-        return max(2 * self.get_io_size(), self.MIN_PREFETCH_SIZE)
+        return max(2 * self.tracker.average, self.MIN_PREFETCH_SIZE)

     def start(self):
         """Starts the background producer loop.
@@ -246,23 +279,44 @@ async def _loop(self):
             if self.is_stopped:
                 break

-            io_size = self.get_io_size()
-            streak = self.get_sequential_streak()
-            prefetch_size = min((streak + 1) * io_size, self.max_prefetch_size)
+            avg_io_size = self.tracker.average
+            streak = self.consumer.sequential_streak
+            is_variable = self.tracker.is_variable
+            last_read_size = self.tracker.last_value
+
+            # Disable prefetching ahead if highly variable AND average > 100MB
+            if is_variable and avg_io_size > PrefetchProducer.VARIABLE_IO_THRESHOLD:
+                logger.debug(
+                    "Highly variable large IO detected. Disabling background prefetching."
+                )
+                prefetch_multiplier = 1
+            elif streak < self.MIN_STREAKS_FOR_PREFETCHING:
+                prefetch_multiplier = 1
+            else:
+                prefetch_multiplier = streak - self.MIN_STREAKS_FOR_PREFETCHING + 1
+
+            if self.queue.empty() or prefetch_multiplier == 1:
+                io_size = last_read_size
+            else:
+                io_size = avg_io_size
+
+            prefetch_size = min(
+                prefetch_multiplier * io_size, self.max_prefetch_size
+            )

             logger.debug(
                 "Producer awake. Current offset: %d, User offset: %d, Prefetch size: %d",
                 self.current_offset,
-                self.get_user_offset(),
+                self.consumer.offset,
                 prefetch_size,
             )

             while (
                 not self.is_stopped
-                and (self.current_offset - self.get_user_offset()) < prefetch_size
+                and (self.current_offset - self.consumer.offset) < prefetch_size
                 and self.current_offset < self.size
             ):
-                user_offset = self.get_user_offset()
+                user_offset = self.consumer.offset
                 space_remaining = self.size - self.current_offset
                 prefetch_space_available = prefetch_size - (
                     self.current_offset - user_offset
@@ -317,7 +371,7 @@ async def _loop(self):
                 exc_info=True,
             )
             self.is_stopped = True
-            self.on_error(e)
+            self.orchestrator._set_error(e)
             await self.queue.put(e)


@@ -332,22 +386,22 @@ def __init__(
         self,
         queue: asyncio.Queue,
         wakeup_event: asyncio.Event,
-        is_producer_stopped,
-        on_error,
+        tracker: RunningAverageTracker,
+        orchestrator: "BackgroundPrefetcher",
     ):
         """Initializes the consumer.

         Args:
             queue (asyncio.Queue): The shared queue containing fetch tasks.
             wakeup_event (asyncio.Event): Event used to wake the producer when more data is needed.
-            is_producer_stopped (Callable): Function returning whether the producer has been halted.
-            on_error (Callable): Callback triggered when a fetch error is encountered.
+            tracker (RunningAverageTracker): Tracker for history of read sizes.
+            orchestrator (BackgroundPrefetcher): The parent object managing the operation.
         """
         logger.debug("Initializing PrefetchConsumer.")
         self.queue = queue
         self.wakeup_event = wakeup_event
-        self.is_producer_stopped = is_producer_stopped
-        self.on_error = on_error
+        self.tracker = tracker
+        self.orchestrator = orchestrator
         self.sequential_streak = 0
         self.offset = 0
         self._current_block = b""
@@ -389,7 +443,11 @@ async def _advance(self, size: int, save_data: bool) -> list[bytes]:
             available = len(self._current_block) - self._current_block_idx

             if not available:
-                if self.is_producer_stopped() and self.queue.empty():
+                is_producer_stopped = (
+                    not hasattr(self.orchestrator, "producer")
+                    or self.orchestrator.producer.is_stopped
+                )
+                if is_producer_stopped and self.queue.empty():
                     logger.debug("Consumer reached EOF.")
                     break

@@ -401,15 +459,30 @@ async def _advance(self, size: int, save_data: bool) -> list[bytes]:

                 if isinstance(task, Exception):
                     logger.error("Consumer retrieved an exception: %s", task)
-                    self.on_error(task)
+                    self.orchestrator._set_error(task)
                     raise task

                 try:
                     block = await task

                     self.sequential_streak += 1
-                    if self.sequential_streak >= 2:
-                        self.wakeup_event.set()
+                    if (
+                        self.sequential_streak
+                        >= PrefetchProducer.MIN_STREAKS_FOR_PREFETCHING
+                    ):
+                        is_variable = self.tracker.is_variable
+                        avg_io_size = self.tracker.average
+
+                        # Suppress proactive wakeups to prevent large CPU assembly on erratic large reads
+                        if not (
+                            is_variable
+                            and avg_io_size > PrefetchProducer.VARIABLE_IO_THRESHOLD
+                        ):
+                            self.wakeup_event.set()
+                        else:
+                            logger.debug(
+                                "Suppressing proactive producer wakeup due to massive variable workload."
+                            )

                     self._current_block = block
                     self._current_block_idx = 0
@@ -418,7 +491,7 @@ async def _advance(self, size: int, save_data: bool) -> list[bytes]:
                     raise
                 except Exception as e:
                     logger.error("Consumer caught an error: %s", e, exc_info=True)
-                    self.on_error(e)
+                    self.orchestrator._set_error(e)
                     raise e

             if not self._current_block:
@@ -523,8 +596,8 @@ def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None)
         self.consumer = PrefetchConsumer(
             queue=self.queue,
             wakeup_event=self.wakeup_event,
-            is_producer_stopped=self._is_producer_stopped,
-            on_error=self._set_error,
+            tracker=self.read_tracker,
+            orchestrator=self,
         )

         self.producer = PrefetchProducer(
@@ -533,10 +606,9 @@ def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None)
            concurrency=self.concurrency,
            queue=self.queue,
            wakeup_event=self.wakeup_event,
-            get_user_offset=lambda: self.consumer.offset,
-            get_io_size=self._get_adaptive_io_size,
-            get_sequential_streak=lambda: self.consumer.sequential_streak,
-            on_error=self._set_error,
+            consumer=self.consumer,
+            tracker=self.read_tracker,
+            orchestrator=self,
            user_max_prefetch_size=max_prefetch_size,
        )

@@ -554,12 +626,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         """Context manager exit point. Ensures the prefetcher is cleanly closed."""
         self.close()

-    def _get_adaptive_io_size(self) -> int:
-        return self.read_tracker.average
-
-    def _is_producer_stopped(self) -> bool:
-        return self.producer.is_stopped if hasattr(self, "producer") else True
-
     def _set_error(self, e: Exception):
         logger.error("Global error state set in BackgroundPrefetcher: %s", e)
         self._error = e
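
Taken together, the producer's new sizing rule can be summarized in a standalone sketch (illustrative only; the names and the 128 MB / 100 MB constants mirror the diff above, not a public API):

    # Hypothetical, simplified model of how the producer sizes each prefetch window.
    MIN_PREFETCH_SIZE = 128 * 1024 * 1024
    VARIABLE_IO_THRESHOLD = 100 * 1024 * 1024
    MIN_STREAKS_FOR_PREFETCHING = 3

    def prefetch_window(avg, last, streak, is_variable, queue_empty, user_max=None):
        """Return how far ahead of the reader the producer should fetch."""
        # Large, erratic reads: stay one chunk ahead to avoid CPU-bound reassembly.
        if is_variable and avg > VARIABLE_IO_THRESHOLD:
            multiplier = 1
        elif streak < MIN_STREAKS_FOR_PREFETCHING:
            multiplier = 1
        else:
            multiplier = streak - MIN_STREAKS_FOR_PREFETCHING + 1

        # No backlog (or no ramp-up yet): size off the most recent read;
        # otherwise use the rolling average.
        io_size = last if (queue_empty or multiplier == 1) else avg

        cap = max(2 * avg, MIN_PREFETCH_SIZE)
        if user_max is not None:
            cap = min(user_max, cap)
        return min(multiplier * io_size, cap)

    # A steady 1 MiB sequential reader on its fifth read:
    # prefetch_window(2**20, 2**20, streak=5, is_variable=False, queue_empty=False) == 3 * 2**20
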

gcsfs/tests/test_prefetcher.py

Lines changed: 6 additions & 4 deletions
@@ -136,9 +136,9 @@ def test_producer_concurrency_streak_and_min_chunk():
    original_min_chunk = bp.producer.MIN_CHUNK_SIZE
    bp.producer.MIN_CHUNK_SIZE = 10

-    bp._fetch(0, 50)
-    bp._fetch(50, 100)
-    bp._fetch(100, 150)
+    # Do 6 reads to push the streak well past the MIN_STREAKS threshold
+    for i in range(6):
+        bp._fetch(i * 50, (i + 1) * 50)

    fsspec.asyn.sync(bp.loop, asyncio.sleep, 0.1)

@@ -507,7 +507,9 @@ def test_producer_min_chunk_inner_break():
    async def trigger_loop():
        bp.producer.current_offset = 250
        bp.consumer.offset = 0
-        bp.consumer.sequential_streak = 3  # makes prefetch_size = (3+1) * 100 = 400
+        # streak=6 makes prefetch_multiplier = 4 (6 - 3 + 1)
+        # prefetch_size = 4 * 100 = 400
+        bp.consumer.sequential_streak = 6
        bp.wakeup_event.set()
        await asyncio.sleep(0.05)
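
As a quick sanity check of the arithmetic the updated tests rely on (a standalone sketch, assuming the producer constants introduced in this commit):

    # streak -> prefetch_multiplier under the new rule
    MIN_STREAKS_FOR_PREFETCHING = 3
    for streak in (2, 3, 6):
        multiplier = 1 if streak < MIN_STREAKS_FOR_PREFETCHING else streak - MIN_STREAKS_FOR_PREFETCHING + 1
        print(streak, "->", multiplier)  # 2 -> 1, 3 -> 1, 6 -> 4
    # With io_size = 100 and streak = 6, prefetch_size = 4 * 100 = 400, matching the test above.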
