@@ -42,15 +42,21 @@ def __init__(
4242 self .agg_mode = agg_mode
4343
4444
45- def op (name : str , deps = None ):
45+ def op (name : str , deps = None , agg_mode : AggMode = AggMode . ALL_REDUCE ):
4646 deps = deps or []
4747
4848 def decorator (func ):
4949 @wraps (func )
5050 def _wrapper (* args , ** kwargs ):
5151 return func (* args , ** kwargs )
5252
53- _wrapper .op_node = OpNode (name , deps , lambda self , ctx : func (self , ** ctx ))
53+ _wrapper .op_node = OpNode (
54+ name = name ,
55+ deps = deps ,
56+ compute_func = lambda self , ctx : func (self ),
57+ callback_func = lambda self , ctx , results : None ,
58+ agg_mode = agg_mode ,
59+ )
5460 return _wrapper
5561
5662 return decorator
@@ -59,48 +65,45 @@ def _wrapper(*args, **kwargs):
5965class Engine :
6066 def __init__ (self , max_workers : int = 4 ):
6167 self .max_workers = max_workers
68+ self .bucket_mgr = BucketManager ()
6269
6370 def run (self , ops : List [OpNode ], ctx : Context ):
64- name2op = {operation .name : operation for operation in ops }
65-
66- # topological sort
67- graph = {n : set (name2op [n ].deps ) for n in name2op }
68- topo = []
69- q = [n for n , d in graph .items () if not d ]
70- while q :
71- cur = q .pop (0 )
72- topo .append (cur )
73- for child in [c for c , d in graph .items () if cur in d ]:
74- graph [child ].remove (cur )
75- if not graph [child ]:
76- q .append (child )
71+ name2op = {op .name : op for op in ops }
72+ topo_names = [op .name for op in self ._topo_sort (ops )]
7773
78- if len (topo ) != len (ops ):
79- raise ValueError (
80- "Cyclic dependencies detected among operations."
81- "Please check your configuration."
82- )
83-
84- # semaphore for max_workers
8574 sem = threading .Semaphore (self .max_workers )
8675 done = {n : threading .Event () for n in name2op }
8776 exc = {}
8877
78+ for node in ops :
79+ bucket_size = ctx .get (f"_bucket_size_{ node .name } " , 1 )
80+ self .bucket_mgr .register (
81+ node .name ,
82+ bucket_size ,
83+ node .agg_mode ,
84+ lambda results , n = node : self ._callback_wrapper (n , ctx , results ),
85+ )
86+
8987 def _exec (n : str ):
9088 with sem :
9189 for d in name2op [n ].deps :
9290 done [d ].wait ()
9391 if any (d in exc for d in name2op [n ].deps ):
94- exc [n ] = Exception ( "Skipped due to failed dependencies" )
92+ exc [n ] = "Skipped due to failed dependencies"
9593 done [n ].set ()
9694 return
95+
9796 try :
98- name2op [n ].func (name2op [n ], ctx )
97+ name2op [n ].compute_func (name2op [n ], ctx )
9998 except Exception : # pylint: disable=broad-except
10099 exc [n ] = traceback .format_exc ()
101- done [n ].set ()
100+ finally :
101+ done [n ].set ()
102102
103- ts = [threading .Thread (target = _exec , args = (n ,), daemon = True ) for n in topo ]
103+ ts = [
104+ threading .Thread (target = _exec , args = (name ,), daemon = True )
105+ for name in topo_names
106+ ]
104107 for t in ts :
105108 t .start ()
106109 for t in ts :
@@ -111,6 +114,34 @@ def _exec(n: str):
111114 + "\n " .join (f"---- { op } ----\n { tb } " for op , tb in exc .items ())
112115 )
113116
117+ @staticmethod
118+ def _callback_wrapper (node : OpNode , ctx : Context , results : List [Any ]):
119+ try :
120+ node .callback_func (node , ctx , results )
121+ except Exception : # pylint: disable=broad-except
122+ traceback .print_exc ()
123+
124+ @staticmethod
125+ def _topo_sort (ops : List [OpNode ]) -> List [OpNode ]:
126+ name2op = {operation .name : operation for operation in ops }
127+ graph = {n : set (name2op [n ].deps ) for n in name2op }
128+ topo = []
129+ q = [n for n , d in graph .items () if not d ]
130+ while q :
131+ cur = q .pop (0 )
132+ topo .append (name2op [cur ])
133+ for child in [c for c , d in graph .items () if cur in d ]:
134+ graph [child ].remove (cur )
135+ if not graph [child ]:
136+ q .append (child )
137+
138+ if len (topo ) != len (ops ):
139+ raise ValueError (
140+ "Cyclic dependencies detected among operations."
141+ "Please check your configuration."
142+ )
143+ return topo
144+
114145
115146class Bucket :
116147 """
@@ -190,8 +221,9 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
190221 op_node .deps = runtime_deps
191222
192223 if "params" in stage :
193- op_node .func = lambda self , ctx , m = method , sc = stage : m (sc .get ("params" , {}))
224+ params = stage ["params" ]
225+ op_node .compute_func = lambda self , ctx , m = method , p = params : m (p )
194226 else :
195- op_node .func = lambda self , ctx , m = method : m ()
227+ op_node .compute_func = lambda self , ctx , m = method : m ()
196228 ops .append (op_node )
197229 return ops
0 commit comments