11import collections
22import logging
3- import threading
3+ from threading import Thread
44from copy import copy , deepcopy
55from collections import deque
66
@@ -153,21 +153,13 @@ def walk(self, walk_func):
153153 Args:
154154 walk_func (:class:`types.FunctionType`): The function to be called
155155 on each node of the graph.
156-
157- Returns:
158- bool: True if the function succeeded on every node, otherwise
159- False.
160156 """
161157 nodes = self .topological_sort ()
162158 # Reverse so we start with nodes that have no dependencies.
163159 nodes .reverse ()
164160
165- failed = False
166161 for n in nodes :
167- if not walk_func (n ):
168- failed = True
169-
170- return not failed
162+ walk_func (n )
171163
172164 def rename_edges (self , old_node_name , new_node_name ):
173165 """ Change references to a node in existing edges.
@@ -383,25 +375,6 @@ def release(self):
383375 pass
384376
385377
386- class Thread (threading .Thread ):
387- """Used when executing walk_func's in parallel, to provide access to the
388- return value.
389- """
390- def __init__ (self , * args , ** kwargs ):
391- super (Thread , self ).__init__ (* args , ** kwargs )
392- self ._return = None
393-
394- def run (self ):
395- if self ._Thread__target is not None :
396- self ._return = self ._Thread__target (
397- * self ._Thread__args ,
398- ** self ._Thread__kwargs )
399-
400- def join (self , * args , ** kwargs ):
401- super (Thread , self ).join (* args , ** kwargs )
402- return self ._return
403-
404-
405378class ThreadedWalker (object ):
406379 """A DAG walker that walks the graph as quickly as the graph topology
407380 allows, using threads.
@@ -437,14 +410,14 @@ def walk(self, dag, walk_func):
437410 # Blocks until all of the given nodes have completed execution (whether
438411 # successfully, or errored). Returns True if all nodes returned True.
439412 def wait_for (nodes ):
440- return all (threads [node ].join () for node in nodes )
413+ for node in nodes :
414+ thread = threads [node ]
415+ while thread .is_alive ():
416+ threads [node ].join (0.5 )
441417
442418 # For each node in the graph, we're going to allocate a thread to
443419 # execute. The thread will block executing walk_func, until all of the
444- # nodes dependencies have executed successfully.
445- #
446- # If any node fails for some reason (e.g. raising an exception), any
447- # downstream nodes will be cancelled.
420+ # nodes dependencies have executed.
448421 for node in nodes :
449422 def fn (n , deps ):
450423 if deps :
@@ -460,17 +433,10 @@ def fn(n, deps):
460433
461434 self .semaphore .acquire ()
462435 try :
463- ret = walk_func (n )
436+ return walk_func (n )
464437 finally :
465438 self .semaphore .release ()
466439
467- if ret :
468- logger .debug ("%s completed" , n )
469- else :
470- logger .debug ("%s failed" , n )
471-
472- return ret
473-
474440 deps = dag .all_downstreams (node )
475441 threads [node ] = Thread (target = fn , args = (node , deps ), name = node )
476442
@@ -479,4 +445,4 @@ def fn(n, deps):
479445 threads [node ].start ()
480446
481447 # Wait for all threads to complete executing.
482- return wait_for (nodes )
448+ wait_for (nodes )
0 commit comments