77from pydantic import BaseModel
88from ulid import ULID
99
10- from dreadnode .api .util import process_run , process_task
10+ from dreadnode .api .util import (
11+ convert_flat_tasks_to_tree ,
12+ convert_flat_trace_to_tree ,
13+ process_run ,
14+ process_task ,
15+ )
1116from dreadnode .util import logger
1217from dreadnode .version import VERSION
1318
2025 RunSummary ,
2126 StatusFilter ,
2227 Task ,
28+ TaskTree ,
2329 TimeAggregationType ,
2430 TimeAxisType ,
2531 TraceSpan ,
32+ TraceTree ,
2633 UserDataCredentials ,
2734)
2835
@@ -134,13 +141,41 @@ def _get_run(self, run: str | ULID) -> RawRun:
134141 def get_run (self , run : str | ULID ) -> Run :
135142 return process_run (self ._get_run (run ))
136143
137- def get_run_tasks (self , run : str | ULID ) -> list [Task ]:
144+ TraceFormat = t .Literal ["tree" , "flat" ]
145+
146+ @t .overload
147+ def get_run_tasks (
148+ self , run : str | ULID , * , format : t .Literal ["tree" ] = "tree"
149+ ) -> list [TaskTree ]: ...
150+
151+ @t .overload
152+ def get_run_tasks (
153+ self , run : str | ULID , * , format : t .Literal ["flat" ] = "flat"
154+ ) -> list [Task ]: ...
155+
156+ def get_run_tasks (
157+ self , run : str | ULID , * , format : TraceFormat = "flat"
158+ ) -> list [Task ] | list [TaskTree ]:
138159 raw_run = self ._get_run (run )
139160 response = self .request ("GET" , f"/strikes/projects/runs/{ run !s} /tasks/full" )
140161 raw_tasks = [RawTask (** task ) for task in response .json ()]
141- return [process_task (task , raw_run ) for task in raw_tasks ]
142-
143- def get_run_trace (self , run : str | ULID ) -> list [Task | TraceSpan ]:
162+ tasks = [process_task (task , raw_run ) for task in raw_tasks ]
163+ tasks = sorted (tasks , key = lambda x : x .timestamp )
164+ return tasks if format == "flat" else convert_flat_tasks_to_tree (tasks )
165+
166+ @t .overload
167+ def get_run_trace (
168+ self , run : str | ULID , * , format : t .Literal ["tree" ] = "tree"
169+ ) -> list [TraceTree ]: ...
170+
171+ @t .overload
172+ def get_run_trace (
173+ self , run : str | ULID , * , format : t .Literal ["flat" ] = "flat"
174+ ) -> list [Task | TraceSpan ]: ...
175+
176+ def get_run_trace (
177+ self , run : str | ULID , * , format : TraceFormat = "flat"
178+ ) -> list [Task | TraceSpan ] | list [TraceTree ]:
144179 raw_run = self ._get_run (run )
145180 response = self .request ("GET" , f"/strikes/projects/runs/{ run !s} /spans/full" )
146181 trace : list [Task | TraceSpan ] = []
@@ -149,7 +184,9 @@ def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]:
149184 trace .append (process_task (RawTask (** item ), raw_run ))
150185 else :
151186 trace .append (TraceSpan (** item ))
152- return trace
187+
188+ trace = sorted (trace , key = lambda x : x .timestamp )
189+ return trace if format == "flat" else convert_flat_trace_to_tree (trace )
153190
154191 # Data exports
155192
0 commit comments