Skip to content

Commit a40b196

Browse files
committed
Add concurrency to run_workflow.py
1 parent 2c62285 commit a40b196

1 file changed

Lines changed: 38 additions & 15 deletions

File tree

examples/run_workflow.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,38 @@
11
"""Helper script to dispatch workflows."""
22
from argparse import ArgumentParser, FileType
33
from metafold import MetafoldClient
4+
from multiprocessing import Pool
45
from pathlib import Path
56
from pprint import pprint
7+
from typing import Optional
68
import json
79
import os
810
import sys
911

1012

13+
def run_workflow(
14+
m: MetafoldClient,
15+
definition: str,
16+
assets: Optional[dict[str, str]] = None,
17+
params: Optional[dict[str, str]] = None,
18+
timeout: int = 5 * 60):
19+
print("Running workflow…")
20+
w = m.workflows.run(definition, assets=assets, parameters=params, timeout=timeout)
21+
22+
print(f"Workflow completed: {w.state}")
23+
24+
for job_id in w.jobs:
25+
j = m.jobs.get(job_id)
26+
match j.state:
27+
case "success":
28+
if j.outputs and j.outputs.assets:
29+
pprint(j.outputs.assets)
30+
if j.outputs and j.outputs.params:
31+
pprint(j.outputs.params)
32+
case "failure":
33+
print(f"Job {j.id} failed: {j.error}")
34+
35+
1136
def main():
1237
parser = ArgumentParser(description="Run workflow from YAML definition")
1338
parser.add_argument(
@@ -16,6 +41,8 @@ def main():
1641

1742
parser.add_argument("--assets", help="workflow asset mapping")
1843
parser.add_argument("--params", help="workflow parameter mapping")
44+
parser.add_argument("--count", type=int, default=1, help="number of repetitions")
45+
parser.add_argument("--num-concurrent", type=int, default=4, help="number of concurrent workflows")
1946

2047
parser.add_argument(
2148
"--asset-uploads", nargs="*",
@@ -65,22 +92,18 @@ def main():
6592
for p in args.asset_uploads:
6693
m.assets.create(p.resolve())
6794

68-
print("Running workflow…")
6995
definition = args.workflow.read()
70-
w = m.workflows.run(definition, assets=assets, parameters=params)
71-
72-
print(f"Workflow completed: {w.state}")
73-
74-
for job_id in w.jobs:
75-
j = m.jobs.get(job_id)
76-
match j.state:
77-
case "success":
78-
if j.outputs and j.outputs.assets:
79-
pprint(j.outputs.assets)
80-
if j.outputs and j.outputs.params:
81-
pprint(j.outputs.params)
82-
case "failure":
83-
print(f"Job {j.id} failed: {j.error}")
96+
if args.count > 1:
97+
num_concurrent = max(args.num_concurrent, 4)
98+
with Pool(num_concurrent) as p:
99+
for _ in range(args.count):
100+
p.apply_async(
101+
run_workflow, (m, definition), {"assets": assets, "params": params})
102+
p.close()
103+
p.join()
104+
105+
elif args.count == 1:
106+
run_workflow(m, definition, assets=assets, params=params)
84107

85108
return 0
86109

0 commit comments

Comments
 (0)