3333
3434from __future__ import annotations
3535
36+ import contextvars
3637import queue
3738import sys
3839from threading import Lock , Thread
@@ -237,12 +238,18 @@ def _allmsgs(obj):
237238class PipelineThread (Thread ):
238239 """Abstract base class for pipeline-stage threads."""
239240
240- def __init__ (self , all_threads ):
241+ def __init__ (self , all_threads , ctx : contextvars . Context | None = None ):
241242 super ().__init__ ()
242243 self .abort_lock = Lock ()
243244 self .abort_flag = False
244245 self .all_threads = all_threads
245246 self .exc_info = None
247+ self .ctx = ctx
248+
249+ def _run_in_context (self , func , * args ):
250+ if self .ctx is None :
251+ return func (* args )
252+ return self .ctx .run (func , * args )
246253
247254 def abort (self ):
248255 """Shut down the thread at the next chance possible."""
@@ -267,8 +274,8 @@ class FirstPipelineThread(PipelineThread):
267274 The coroutine should just be a generator.
268275 """
269276
270- def __init__ (self , coro , out_queue , all_threads ):
271- super ().__init__ (all_threads )
277+ def __init__ (self , coro , out_queue , all_threads , ctx = None ):
278+ super ().__init__ (all_threads , ctx )
272279 self .coro = coro
273280 self .out_queue = out_queue
274281 self .out_queue .acquire ()
@@ -282,7 +289,7 @@ def run(self):
282289
283290 # Get the value from the generator.
284291 try :
285- msg = next ( self .coro )
292+ msg = self . _run_in_context ( next , self .coro )
286293 except StopIteration :
287294 break
288295
@@ -306,8 +313,8 @@ class MiddlePipelineThread(PipelineThread):
306313 last.
307314 """
308315
309- def __init__ (self , coro , in_queue , out_queue , all_threads ):
310- super ().__init__ (all_threads )
316+ def __init__ (self , coro , in_queue , out_queue , all_threads , ctx = None ):
317+ super ().__init__ (all_threads , ctx )
311318 self .coro = coro
312319 self .in_queue = in_queue
313320 self .out_queue = out_queue
@@ -316,7 +323,7 @@ def __init__(self, coro, in_queue, out_queue, all_threads):
316323 def run (self ):
317324 try :
318325 # Prime the coroutine.
319- next ( self .coro )
326+ self . _run_in_context ( next , self .coro )
320327
321328 while True :
322329 with self .abort_lock :
@@ -333,7 +340,7 @@ def run(self):
333340 return
334341
335342 # Invoke the current stage.
336- out = self .coro .send ( msg )
343+ out = self ._run_in_context ( self . coro .send , msg )
337344
338345 # Send messages to next stage.
339346 for msg in _allmsgs (out ):
@@ -355,14 +362,14 @@ class LastPipelineThread(PipelineThread):
355362 should yield nothing.
356363 """
357364
358- def __init__ (self , coro , in_queue , all_threads ):
359- super ().__init__ (all_threads )
365+ def __init__ (self , coro , in_queue , all_threads , ctx = None ):
366+ super ().__init__ (all_threads , ctx )
360367 self .coro = coro
361368 self .in_queue = in_queue
362369
363370 def run (self ):
364371 # Prime the coroutine.
365- next ( self .coro )
372+ self . _run_in_context ( next , self .coro )
366373
367374 try :
368375 while True :
@@ -380,7 +387,7 @@ def run(self):
380387 return
381388
382389 # Send to consumer.
383- self .coro .send ( msg )
390+ self ._run_in_context ( self . coro .send , msg )
384391
385392 except BaseException :
386393 self .abort_all (sys .exc_info ())
@@ -419,26 +426,37 @@ def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
419426 messages between the stages are stored in queues of the given
420427 size.
421428 """
429+ base_ctx = contextvars .copy_context ()
422430 queue_count = len (self .stages ) - 1
423431 queues = [CountedQueue (queue_size ) for i in range (queue_count )]
424432 threads = []
425433
426434 # Set up first stage.
427435 for coro in self .stages [0 ]:
428- threads .append (FirstPipelineThread (coro , queues [0 ], threads ))
436+ # Each worker needs its own copy because Context objects cannot be
437+ # entered concurrently from multiple threads.
438+ threads .append (
439+ FirstPipelineThread (coro , queues [0 ], threads , base_ctx .copy ())
440+ )
429441
430442 # Middle stages.
431443 for i in range (1 , queue_count ):
432444 for coro in self .stages [i ]:
433445 threads .append (
434446 MiddlePipelineThread (
435- coro , queues [i - 1 ], queues [i ], threads
447+ coro ,
448+ queues [i - 1 ],
449+ queues [i ],
450+ threads ,
451+ base_ctx .copy (),
436452 )
437453 )
438454
439455 # Last stage.
440456 for coro in self .stages [- 1 ]:
441- threads .append (LastPipelineThread (coro , queues [- 1 ], threads ))
457+ threads .append (
458+ LastPipelineThread (coro , queues [- 1 ], threads , base_ctx .copy ())
459+ )
442460
443461 # Start threads.
444462 for thread in threads :
0 commit comments