Skip to content

Commit 3927239

Browse files
fix: fix params in collect_ops
1 parent 8d7f2c5 commit 3927239

1 file changed

Lines changed: 60 additions & 28 deletions

File tree

graphgen/engine.py

Lines changed: 60 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -42,15 +42,21 @@ def __init__(
4242
self.agg_mode = agg_mode
4343

4444

45-
def op(name: str, deps=None):
45+
def op(name: str, deps=None, agg_mode: AggMode = AggMode.ALL_REDUCE):
4646
deps = deps or []
4747

4848
def decorator(func):
4949
@wraps(func)
5050
def _wrapper(*args, **kwargs):
5151
return func(*args, **kwargs)
5252

53-
_wrapper.op_node = OpNode(name, deps, lambda self, ctx: func(self, **ctx))
53+
_wrapper.op_node = OpNode(
54+
name=name,
55+
deps=deps,
56+
compute_func=lambda self, ctx: func(self),
57+
callback_func=lambda self, ctx, results: None,
58+
agg_mode=agg_mode,
59+
)
5460
return _wrapper
5561

5662
return decorator
@@ -59,48 +65,45 @@ def _wrapper(*args, **kwargs):
5965
class Engine:
6066
def __init__(self, max_workers: int = 4):
6167
self.max_workers = max_workers
68+
self.bucket_mgr = BucketManager()
6269

6370
def run(self, ops: List[OpNode], ctx: Context):
64-
name2op = {operation.name: operation for operation in ops}
65-
66-
# topological sort
67-
graph = {n: set(name2op[n].deps) for n in name2op}
68-
topo = []
69-
q = [n for n, d in graph.items() if not d]
70-
while q:
71-
cur = q.pop(0)
72-
topo.append(cur)
73-
for child in [c for c, d in graph.items() if cur in d]:
74-
graph[child].remove(cur)
75-
if not graph[child]:
76-
q.append(child)
71+
name2op = {op.name: op for op in ops}
72+
topo_names = [op.name for op in self._topo_sort(ops)]
7773

78-
if len(topo) != len(ops):
79-
raise ValueError(
80-
"Cyclic dependencies detected among operations."
81-
"Please check your configuration."
82-
)
83-
84-
# semaphore for max_workers
8574
sem = threading.Semaphore(self.max_workers)
8675
done = {n: threading.Event() for n in name2op}
8776
exc = {}
8877

78+
for node in ops:
79+
bucket_size = ctx.get(f"_bucket_size_{node.name}", 1)
80+
self.bucket_mgr.register(
81+
node.name,
82+
bucket_size,
83+
node.agg_mode,
84+
lambda results, n=node: self._callback_wrapper(n, ctx, results),
85+
)
86+
8987
def _exec(n: str):
9088
with sem:
9189
for d in name2op[n].deps:
9290
done[d].wait()
9391
if any(d in exc for d in name2op[n].deps):
94-
exc[n] = Exception("Skipped due to failed dependencies")
92+
exc[n] = "Skipped due to failed dependencies"
9593
done[n].set()
9694
return
95+
9796
try:
98-
name2op[n].func(name2op[n], ctx)
97+
name2op[n].compute_func(name2op[n], ctx)
9998
except Exception: # pylint: disable=broad-except
10099
exc[n] = traceback.format_exc()
101-
done[n].set()
100+
finally:
101+
done[n].set()
102102

103-
ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
103+
ts = [
104+
threading.Thread(target=_exec, args=(name,), daemon=True)
105+
for name in topo_names
106+
]
104107
for t in ts:
105108
t.start()
106109
for t in ts:
@@ -111,6 +114,34 @@ def _exec(n: str):
111114
+ "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
112115
)
113116

117+
@staticmethod
118+
def _callback_wrapper(node: OpNode, ctx: Context, results: List[Any]):
119+
try:
120+
node.callback_func(node, ctx, results)
121+
except Exception: # pylint: disable=broad-except
122+
traceback.print_exc()
123+
124+
@staticmethod
125+
def _topo_sort(ops: List[OpNode]) -> List[OpNode]:
126+
name2op = {operation.name: operation for operation in ops}
127+
graph = {n: set(name2op[n].deps) for n in name2op}
128+
topo = []
129+
q = [n for n, d in graph.items() if not d]
130+
while q:
131+
cur = q.pop(0)
132+
topo.append(name2op[cur])
133+
for child in [c for c, d in graph.items() if cur in d]:
134+
graph[child].remove(cur)
135+
if not graph[child]:
136+
q.append(child)
137+
138+
if len(topo) != len(ops):
139+
raise ValueError(
140+
"Cyclic dependencies detected among operations."
141+
"Please check your configuration."
142+
)
143+
return topo
144+
114145

115146
class Bucket:
116147
"""
@@ -190,8 +221,9 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
190221
op_node.deps = runtime_deps
191222

192223
if "params" in stage:
193-
op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {}))
224+
params = stage["params"]
225+
op_node.compute_func = lambda self, ctx, m=method, p=params: m(p)
194226
else:
195-
op_node.func = lambda self, ctx, m=method: m()
227+
op_node.compute_func = lambda self, ctx, m=method: m()
196228
ops.append(op_node)
197229
return ops

0 commit comments

Comments
 (0)