Skip to content

Commit 3a246e0

Browse files
committed
Add tree structuring to task and trace api functions
1 parent e99c704 commit 3a246e0

3 files changed

Lines changed: 82 additions & 10 deletions

File tree

dreadnode/api/client.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from pydantic import BaseModel
88
from ulid import ULID
99

10-
from dreadnode.api.util import process_run, process_task
10+
from dreadnode.api.util import (
11+
convert_flat_tasks_to_tree,
12+
convert_flat_trace_to_tree,
13+
process_run,
14+
process_task,
15+
)
1116
from dreadnode.util import logger
1217
from dreadnode.version import VERSION
1318

@@ -20,9 +25,11 @@
2025
RunSummary,
2126
StatusFilter,
2227
Task,
28+
TaskTree,
2329
TimeAggregationType,
2430
TimeAxisType,
2531
TraceSpan,
32+
TraceTree,
2633
UserDataCredentials,
2734
)
2835

@@ -134,13 +141,41 @@ def _get_run(self, run: str | ULID) -> RawRun:
134141
def get_run(self, run: str | ULID) -> Run:
135142
return process_run(self._get_run(run))
136143

137-
def get_run_tasks(self, run: str | ULID) -> list[Task]:
144+
TraceFormat = t.Literal["tree", "flat"]
145+
146+
@t.overload
147+
def get_run_tasks(
148+
self, run: str | ULID, *, format: t.Literal["tree"] = "tree"
149+
) -> list[TaskTree]: ...
150+
151+
@t.overload
152+
def get_run_tasks(
153+
self, run: str | ULID, *, format: t.Literal["flat"] = "flat"
154+
) -> list[Task]: ...
155+
156+
def get_run_tasks(
157+
self, run: str | ULID, *, format: TraceFormat = "flat"
158+
) -> list[Task] | list[TaskTree]:
138159
raw_run = self._get_run(run)
139160
response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks/full")
140161
raw_tasks = [RawTask(**task) for task in response.json()]
141-
return [process_task(task, raw_run) for task in raw_tasks]
142-
143-
def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]:
162+
tasks = [process_task(task, raw_run) for task in raw_tasks]
163+
tasks = sorted(tasks, key=lambda x: x.timestamp)
164+
return tasks if format == "flat" else convert_flat_tasks_to_tree(tasks)
165+
166+
@t.overload
167+
def get_run_trace(
168+
self, run: str | ULID, *, format: t.Literal["tree"] = "tree"
169+
) -> list[TraceTree]: ...
170+
171+
@t.overload
172+
def get_run_trace(
173+
self, run: str | ULID, *, format: t.Literal["flat"] = "flat"
174+
) -> list[Task | TraceSpan]: ...
175+
176+
def get_run_trace(
177+
self, run: str | ULID, *, format: TraceFormat = "flat"
178+
) -> list[Task | TraceSpan] | list[TraceTree]:
144179
raw_run = self._get_run(run)
145180
response = self.request("GET", f"/strikes/projects/runs/{run!s}/spans/full")
146181
trace: list[Task | TraceSpan] = []
@@ -149,7 +184,9 @@ def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]:
149184
trace.append(process_task(RawTask(**item), raw_run))
150185
else:
151186
trace.append(TraceSpan(**item))
152-
return trace
187+
188+
trace = sorted(trace, key=lambda x: x.timestamp)
189+
return trace if format == "flat" else convert_flat_trace_to_tree(trace)
153190

154191
# Data exports
155192

dreadnode/api/models.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,9 @@ class TaskTree(BaseModel):
285285
children: list["TaskTree"] = []
286286

287287

288-
class SpanTree(BaseModel):
289-
"""Tree representation of a trace span with its children"""
290-
288+
class TraceTree(BaseModel):
291289
span: Task | TraceSpan
292-
children: list["SpanTree"] = []
290+
children: list["TraceTree"] = []
293291

294292

295293
# User data credentials

dreadnode/api/util.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,49 @@
99
RawTask,
1010
Run,
1111
Task,
12+
TaskTree,
13+
TraceSpan,
14+
TraceTree,
1215
V0Object,
1316
)
1417

1518
logger = getLogger(__name__)
1619

1720

21+
def convert_flat_tasks_to_tree(tasks: list[Task]) -> list[TaskTree]:
22+
tree_nodes: dict[str, TaskTree] = {}
23+
for task in tasks:
24+
tree_nodes[task.span_id] = TaskTree(task=task)
25+
26+
roots: list[TaskTree] = []
27+
for task in tasks:
28+
if task.parent_task_span_id not in tree_nodes:
29+
roots.append(tree_nodes[task.span_id])
30+
else:
31+
parent_node = tree_nodes.get(task.parent_task_span_id)
32+
if parent_node:
33+
parent_node.children.append(tree_nodes[task.span_id])
34+
35+
return roots
36+
37+
38+
def convert_flat_trace_to_tree(trace: list[Task | TraceSpan]) -> list[TraceTree]:
39+
tree_nodes: dict[str, TraceTree] = {}
40+
for span in trace:
41+
tree_nodes[span.span_id] = TraceTree(span=span)
42+
43+
roots: list[TraceTree] = []
44+
for span in trace:
45+
if span.parent_span_id not in tree_nodes:
46+
roots.append(tree_nodes[span.span_id])
47+
else:
48+
parent_node = tree_nodes.get(span.parent_span_id)
49+
if parent_node:
50+
parent_node.children.append(tree_nodes[span.span_id])
51+
52+
return roots
53+
54+
1855
def process_run(run: RawRun) -> Run:
1956
inputs: dict[str, Object] = {}
2057
outputs: dict[str, Object] = {}

0 commit comments

Comments
 (0)