11from attrs import field , frozen
22from datetime import datetime
33from metafold .api import asdatetime , asdict , optional_datetime
4+ from metafold .assets import Asset
45from metafold .client import Client
56from metafold .exceptions import PollTimeout
7+ from metafold .jobs import Job
68from requests import Response
79from typing import Optional , Union
10+ import typing
11+
12+ if typing .TYPE_CHECKING :
13+ from metafold import MetafoldClient
814
915
1016@frozen (kw_only = True )
@@ -21,6 +27,9 @@ class Workflow:
2127 definition: Workflow definition string.
2228 project_id: Project ID.
2329 """
30+ _client : "MetafoldClient"
31+ _jobs : dict [str , str ] = field (factory = dict , init = False )
32+
2433 id : str
2534 jobs : list [str ] = field (factory = list )
2635 state : str
@@ -32,6 +41,52 @@ class Workflow:
3241 definition : str
3342 project_id : str
3443
44+ def get_asset (self , path : str ) -> Asset | None :
45+ """Retrieve an asset from the workflow by dot notation.
46+
47+ Args:
48+ path: Path to asset in the form "job.name", e.g. "sample-mesh.volume"
49+ searches for the asset "volume" from the "sample-mesh" job.
50+ """
51+ job_name , asset_name = self ._parse_path (path )
52+ job = self ._find_job (job_name )
53+ if not job or not job .outputs .assets :
54+ return
55+ for name , asset in job .outputs .assets .items ():
56+ if name == asset_name :
57+ return asset
58+
59+ def get_parameter (self , path : str ) -> str | None :
60+ """Retrieve a parameter from the workflow by dot notation.
61+
62+ Args:
63+ path: Path to parameter in the form "job.name", e.g. "sample-mesh.patch_size"
64+ searches for the parameter "patch_size" from the "sample-mesh" job.
65+ """
66+ job_name , param_name = self ._parse_path (path )
67+ job = self ._find_job (job_name )
68+ if not job or not job .outputs .params :
69+ return
70+ for name , param in job .outputs .params .items ():
71+ if name == param_name :
72+ return param
73+
74+ def _find_job (self , name : str ) -> Job | None :
75+ # FIXME(ryan): Update API to return job names as well as IDs.
76+ # For now we cache a mapping b/w job name and job id.
77+ if job_id := self ._jobs .get (name ):
78+ return self ._client .jobs .get (job_id )
79+
80+ for job_id in self .jobs :
81+ job = self ._client .jobs .get (job_id )
82+ if job .name == name :
83+ self ._jobs [name ] = job_id
84+ return job
85+
86+ @staticmethod
87+ def _parse_path (path : str ) -> tuple [str , str ]:
88+ return path .split ("." , maxsplit = 1 )
89+
3590
3691class WorkflowsEndpoint :
3792 """Metafold workflows endpoint."""
@@ -61,7 +116,7 @@ def list(
61116 url = f"/projects/{ project_id } /workflows"
62117 payload = asdict (sort = sort , q = q )
63118 r : Response = self ._client .get (url , params = payload )
64- return [Workflow (** w ) for w in r .json ()]
119+ return [Workflow (client = self . _client , ** w ) for w in r .json ()]
65120
66121 def get (self , workflow_id : str , project_id : Optional [str ] = None ) -> Workflow :
67122 """Get a workflow.
@@ -76,7 +131,7 @@ def get(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
76131 project_id = self ._client .project_id (project_id )
77132 url = f"/projects/{ project_id } /workflows/{ workflow_id } "
78133 r : Response = self ._client .get (url )
79- return Workflow (** r .json ())
134+ return Workflow (client = self . _client , ** r .json ())
80135
81136 def run (
82137 self , definition : str ,
@@ -110,7 +165,7 @@ def run(
110165 raise RuntimeError (
111166 f"Workflow failed to complete within { timeout } seconds"
112167 ) from e
113- return Workflow (** r .json ())
168+ return Workflow (client = self . _client , ** r .json ())
114169
115170 def cancel (self , workflow_id : str , project_id : Optional [str ] = None ) -> Workflow :
116171 """Cancel a running workflow.
@@ -125,7 +180,7 @@ def cancel(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow
125180 project_id = self ._client .project_id (project_id )
126181 url = f"/projects/{ project_id } /workflows/{ workflow_id } /cancel"
127182 r : Response = self ._client .post (url )
128- return Workflow (** r .json ())
183+ return Workflow (client = self . _client , ** r .json ())
129184
130185 def delete (self , workflow_id : str , project_id : Optional [str ] = None ):
131186 """Delete a workflow.
0 commit comments