Skip to content

Commit b27f966

Browse files
committed
bpo-39190: Fix deadlock when callback raises
1 parent c3a651a commit b27f966

File tree

2 files changed

+62
-18
lines changed

2 files changed

+62
-18
lines changed

Lib/multiprocessing/pool.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,9 @@ def _handle_results(outqueue, get, cache):
592592
cache[job]._set(i, obj)
593593
except KeyError:
594594
pass
595+
except Exception:
596+
# Even if we raised we still want to handle callbacks
597+
traceback.print_exc()
595598
task = job = obj = None
596599

597600
while cache and thread._state != TERMINATE:
@@ -609,6 +612,9 @@ def _handle_results(outqueue, get, cache):
609612
cache[job]._set(i, obj)
610613
except KeyError:
611614
pass
615+
except Exception:
616+
# Even if we raised we still want to handle callbacks
617+
traceback.print_exc()
612618
task = job = obj = None
613619

614620
if hasattr(outqueue, '_reader'):
@@ -772,13 +778,15 @@ def get(self, timeout=None):
772778

773779
def _set(self, i, obj):
774780
self._success, self._value = obj
775-
if self._callback and self._success:
776-
self._callback(self._value)
777-
if self._error_callback and not self._success:
778-
self._error_callback(self._value)
779-
self._event.set()
780-
del self._cache[self._job]
781-
self._pool = None
781+
try:
782+
if self._callback and self._success:
783+
self._callback(self._value)
784+
if self._error_callback and not self._success:
785+
self._error_callback(self._value)
786+
finally:
787+
self._event.set()
788+
del self._cache[self._job]
789+
self._pool = None
782790

783791
__class_getitem__ = classmethod(types.GenericAlias)
784792

@@ -809,23 +817,27 @@ def _set(self, i, success_result):
809817
if success and self._success:
810818
self._value[i*self._chunksize:(i+1)*self._chunksize] = result
811819
if self._number_left == 0:
812-
if self._callback:
813-
self._callback(self._value)
814-
del self._cache[self._job]
815-
self._event.set()
816-
self._pool = None
820+
try:
821+
if self._callback:
822+
self._callback(self._value)
823+
finally:
824+
del self._cache[self._job]
825+
self._event.set()
826+
self._pool = None
817827
else:
818828
if not success and self._success:
819829
# only store first exception
820830
self._success = False
821831
self._value = result
822832
if self._number_left == 0:
823833
# only consider the result ready once all jobs are done
824-
if self._error_callback:
825-
self._error_callback(self._value)
826-
del self._cache[self._job]
827-
self._event.set()
828-
self._pool = None
834+
try:
835+
if self._error_callback:
836+
self._error_callback(self._value)
837+
finally:
838+
del self._cache[self._job]
839+
self._event.set()
840+
self._pool = None
829841

830842
#
831843
# Class whose instances are returned by `Pool.imap()`

Lib/test/_test_multiprocessing.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2741,6 +2741,39 @@ def errback(exc):
27412741
p.close()
27422742
p.join()
27432743

2744+
class _TestPoolResultHandlerErrors(BaseTestCase):
2745+
ALLOWED_TYPES = ('processes', )
2746+
2747+
def test_apply_async_callback_raises_exception(self):
2748+
p = multiprocessing.Pool(1)
2749+
2750+
def job():
2751+
return 1
2752+
2753+
def callback(value):
2754+
raise Exception()
2755+
2756+
p.apply_async(job, callback=callback)
2757+
2758+
self.assertTrue(p._result_handler.is_alive())
2759+
p.close()
2760+
p.join()
2761+
2762+
def test_map_async_callback_raises_exception(self):
2763+
p = multiprocessing.Pool(1)
2764+
2765+
def job(value):
2766+
return value
2767+
2768+
def callback(value):
2769+
raise Exception()
2770+
2771+
p.map_async(job, [1], callback=callback)
2772+
2773+
self.assertTrue(p._result_handler.is_alive())
2774+
p.close()
2775+
p.join()
2776+
27442777
class _TestPoolWorkerLifetime(BaseTestCase):
27452778
ALLOWED_TYPES = ('processes', )
27462779

@@ -5740,7 +5773,6 @@ def install_tests_in_module_dict(remote_globs, start_method):
57405773
__module__ = remote_globs['__name__']
57415774
local_globs = globals()
57425775
ALL_TYPES = {'processes', 'threads', 'manager'}
5743-
57445776
for name, base in local_globs.items():
57455777
if not isinstance(base, type):
57465778
continue

0 commit comments

Comments
 (0)