Skip to content

Commit 0b141bf

Browse files
authored
fix: API-level task/run input and output handling (#50)
* Fixes for task/run input and output handling. Add artifacts for runs * Add tree structuring to task and trace api functions
1 parent 72a72fc commit 0b141bf

5 files changed

Lines changed: 620 additions & 303 deletions

File tree

dreadnode/api/client.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,29 @@
77
from pydantic import BaseModel
88
from ulid import ULID
99

10+
from dreadnode.api.util import (
11+
convert_flat_tasks_to_tree,
12+
convert_flat_trace_to_tree,
13+
process_run,
14+
process_task,
15+
)
1016
from dreadnode.util import logger
1117
from dreadnode.version import VERSION
1218

1319
from .models import (
1420
MetricAggregationType,
1521
Project,
22+
RawRun,
23+
RawTask,
1624
Run,
25+
RunSummary,
1726
StatusFilter,
1827
Task,
28+
TaskTree,
1929
TimeAggregationType,
2030
TimeAxisType,
2131
TraceSpan,
32+
TraceTree,
2233
UserDataCredentials,
2334
)
2435

@@ -119,27 +130,63 @@ def get_project(self, project: str) -> Project:
119130
response = self.request("GET", f"/strikes/projects/{project!s}")
120131
return Project(**response.json())
121132

122-
def list_runs(self, project: str) -> list[Run]:
133+
def list_runs(self, project: str) -> list[RunSummary]:
123134
response = self.request("GET", f"/strikes/projects/{project!s}/runs")
124-
return [Run(**run) for run in response.json()]
135+
return [RunSummary(**run) for run in response.json()]
125136

126-
def get_run(self, run: str | ULID) -> Run:
137+
def _get_run(self, run: str | ULID) -> RawRun:
127138
response = self.request("GET", f"/strikes/projects/runs/{run!s}")
128-
return Run(**response.json())
139+
return RawRun(**response.json())
129140

130-
def get_run_tasks(self, run: str | ULID) -> list[Task]:
131-
response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks/full")
132-
return [Task(**task) for task in response.json()]
141+
def get_run(self, run: str | ULID) -> Run:
142+
return process_run(self._get_run(run))
143+
144+
TraceFormat = t.Literal["tree", "flat"]
133145

134-
def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]:
146+
@t.overload
147+
def get_run_tasks(
148+
self, run: str | ULID, *, format: t.Literal["tree"] = "tree"
149+
) -> list[TaskTree]: ...
150+
151+
@t.overload
152+
def get_run_tasks(
153+
self, run: str | ULID, *, format: t.Literal["flat"] = "flat"
154+
) -> list[Task]: ...
155+
156+
def get_run_tasks(
157+
self, run: str | ULID, *, format: TraceFormat = "flat"
158+
) -> list[Task] | list[TaskTree]:
159+
raw_run = self._get_run(run)
160+
response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks/full")
161+
raw_tasks = [RawTask(**task) for task in response.json()]
162+
tasks = [process_task(task, raw_run) for task in raw_tasks]
163+
tasks = sorted(tasks, key=lambda x: x.timestamp)
164+
return tasks if format == "flat" else convert_flat_tasks_to_tree(tasks)
165+
166+
@t.overload
167+
def get_run_trace(
168+
self, run: str | ULID, *, format: t.Literal["tree"] = "tree"
169+
) -> list[TraceTree]: ...
170+
171+
@t.overload
172+
def get_run_trace(
173+
self, run: str | ULID, *, format: t.Literal["flat"] = "flat"
174+
) -> list[Task | TraceSpan]: ...
175+
176+
def get_run_trace(
177+
self, run: str | ULID, *, format: TraceFormat = "flat"
178+
) -> list[Task | TraceSpan] | list[TraceTree]:
179+
raw_run = self._get_run(run)
135180
response = self.request("GET", f"/strikes/projects/runs/{run!s}/spans/full")
136-
spans: list[Task | TraceSpan] = []
181+
trace: list[Task | TraceSpan] = []
137182
for item in response.json():
138183
if "parent_task_span_id" in item:
139-
spans.append(Task(**item))
184+
trace.append(process_task(RawTask(**item), raw_run))
140185
else:
141-
spans.append(TraceSpan(**item))
142-
return spans
186+
trace.append(TraceSpan(**item))
187+
188+
trace = sorted(trace, key=lambda x: x.timestamp)
189+
return trace if format == "flat" else convert_flat_trace_to_tree(trace)
143190

144191
# Data exports
145192

dreadnode/api/models.py

Lines changed: 135 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
1+
import contextlib
12
import typing as t
23
from datetime import datetime
4+
from functools import cached_property
35
from uuid import UUID
46

5-
from pydantic import BaseModel, Field
7+
import requests
8+
from pydantic import (
9+
BaseModel,
10+
ConfigDict,
11+
Field,
12+
PrivateAttr,
13+
TypeAdapter,
14+
ValidationError,
15+
field_validator,
16+
)
617
from ulid import ULID
718

819
AnyDict = dict[str, t.Any]
@@ -79,17 +90,17 @@ class TraceLog(BaseModel):
7990
class TraceSpan(BaseModel):
8091
timestamp: datetime
8192
duration: int
82-
trace_id: str
93+
trace_id: str = Field(repr=False)
8394
span_id: str
84-
parent_span_id: str | None
85-
service_name: str | None
95+
parent_span_id: str | None = Field(repr=False)
96+
service_name: str | None = Field(repr=False)
8697
status: SpanStatus
8798
exception: SpanException | None
8899
name: str
89-
attributes: AnyDict
90-
resource_attributes: AnyDict
91-
events: list[SpanEvent]
92-
links: list[SpanLink]
100+
attributes: AnyDict = Field(repr=False)
101+
resource_attributes: AnyDict = Field(repr=False)
102+
events: list[SpanEvent] = Field(repr=False)
103+
links: list[SpanLink] = Field(repr=False)
93104

94105

95106
class Metric(BaseModel):
@@ -105,22 +116,22 @@ class ObjectRef(BaseModel):
105116
hash: str
106117

107118

108-
class ObjectUri(BaseModel):
119+
class RawObjectUri(BaseModel):
109120
hash: str
110121
schema_hash: str
111122
uri: str
112123
size: int
113124
type: t.Literal["uri"]
114125

115126

116-
class ObjectVal(BaseModel):
127+
class RawObjectVal(BaseModel):
117128
hash: str
118129
schema_hash: str
119130
value: t.Any
120131
type: t.Literal["val"]
121132

122133

123-
Object = ObjectUri | ObjectVal
134+
RawObject = RawObjectUri | RawObjectVal
124135

125136

126137
class V0Object(BaseModel):
@@ -129,56 +140,141 @@ class V0Object(BaseModel):
129140
value: t.Any
130141

131142

132-
class Run(BaseModel):
143+
class ObjectVal(BaseModel):
144+
model_config = ConfigDict(arbitrary_types_allowed=True)
145+
146+
name: str
147+
label: str
148+
hash: str = Field(repr=False)
149+
schema_: AnyDict
150+
schema_hash: str = Field(repr=False)
151+
value: t.Any
152+
153+
@field_validator("value")
154+
@classmethod
155+
def validate_value(cls, value: t.Any) -> t.Any:
156+
if isinstance(value, str):
157+
with contextlib.suppress(ValidationError):
158+
return TypeAdapter(t.Any).validate_json(value)
159+
160+
return value
161+
162+
163+
class ObjectUri(BaseModel):
164+
name: str
165+
label: str
166+
hash: str = Field(repr=False)
167+
schema_: AnyDict
168+
schema_hash: str = Field(repr=False)
169+
uri: str
170+
size: int
171+
172+
_value: t.Any = PrivateAttr(default=None)
173+
174+
@cached_property
175+
def value(self) -> t.Any:
176+
if self._value is not None:
177+
return self._value
178+
179+
try:
180+
response = requests.get(self.uri, timeout=5)
181+
response.raise_for_status()
182+
self._value = response.text
183+
except requests.RequestException as e:
184+
raise RuntimeError(f"Failed to fetch object from {self.uri}") from e
185+
186+
if isinstance(self._value, str):
187+
with contextlib.suppress(ValidationError):
188+
self._value = TypeAdapter(t.Any).validate_json(self._value)
189+
190+
return self._value
191+
192+
193+
Object = ObjectVal | ObjectUri
194+
195+
196+
class ArtifactFile(BaseModel):
197+
hash: str
198+
uri: str
199+
size_bytes: int
200+
final_real_path: str
201+
202+
203+
class ArtifactDir(BaseModel):
204+
dir_path: str
205+
hash: str
206+
children: list[t.Union["ArtifactDir", ArtifactFile]]
207+
208+
209+
class RunSummary(BaseModel):
133210
id: ULID
134211
name: str
135-
span_id: str
136-
trace_id: str
212+
span_id: str = Field(repr=False)
213+
trace_id: str = Field(repr=False)
137214
timestamp: datetime
138215
duration: int
139216
status: SpanStatus
140217
exception: SpanException | None
141218
tags: set[str]
142-
params: AnyDict
143-
metrics: dict[str, list[Metric]]
144-
inputs: list[ObjectRef]
145-
outputs: list[ObjectRef]
146-
objects: dict[str, Object]
147-
object_schemas: AnyDict
148-
schema_: AnyDict = Field(alias="schema")
219+
params: AnyDict = Field(repr=False)
220+
metrics: dict[str, list[Metric]] = Field(repr=False)
149221

150222

151-
class Task(BaseModel):
223+
class RawRun(RunSummary):
224+
inputs: list[ObjectRef] = Field(repr=False)
225+
outputs: list[ObjectRef] = Field(repr=False)
226+
objects: dict[str, RawObject] = Field(repr=False)
227+
object_schemas: AnyDict = Field(repr=False)
228+
artifacts: list[ArtifactDir] = Field(repr=False)
229+
schema_: AnyDict = Field(alias="schema", repr=False)
230+
231+
232+
class Run(RunSummary):
233+
inputs: dict[str, Object] = Field(repr=False)
234+
outputs: dict[str, Object] = Field(repr=False)
235+
artifacts: list[ArtifactDir] = Field(repr=False)
236+
schema_: AnyDict = Field(alias="schema", repr=False)
237+
238+
239+
class _Task(BaseModel):
152240
name: str
153241
span_id: str
154-
trace_id: str
155-
parent_span_id: str | None
156-
parent_task_span_id: str | None
242+
trace_id: str = Field(repr=False)
243+
parent_span_id: str | None = Field(repr=False)
244+
parent_task_span_id: str | None = Field(repr=False)
157245
timestamp: datetime
158246
duration: int
159247
status: SpanStatus
160248
exception: SpanException | None
161249
tags: set[str]
162-
params: AnyDict
163-
metrics: dict[str, list[Metric]]
164-
inputs: list[ObjectRef] | list[V0Object] # v0 compat
165-
outputs: list[ObjectRef] | list[V0Object] # v0 compat
166-
schema_: AnyDict = Field(alias="schema")
167-
attributes: AnyDict
168-
resource_attributes: AnyDict
169-
events: list[SpanEvent]
170-
links: list[SpanLink]
250+
params: AnyDict = Field(repr=False)
251+
metrics: dict[str, list[Metric]] = Field(repr=False)
252+
schema_: AnyDict = Field(alias="schema", repr=False)
253+
attributes: AnyDict = Field(repr=False)
254+
resource_attributes: AnyDict = Field(repr=False)
255+
events: list[SpanEvent] = Field(repr=False)
256+
links: list[SpanLink] = Field(repr=False)
257+
258+
259+
class RawTask(_Task):
260+
inputs: list[ObjectRef] | list[V0Object] = Field(repr=False)
261+
outputs: list[ObjectRef] | list[V0Object] = Field(repr=False)
262+
263+
264+
class Task(_Task):
265+
inputs: dict[str, Object] = Field(repr=False)
266+
outputs: dict[str, Object] = Field(repr=False)
171267

172268

173269
class Project(BaseModel):
174-
id: UUID
270+
id: UUID = Field(repr=False)
175271
key: str
176272
name: str
177-
description: str | None
273+
description: str | None = Field(repr=False)
178274
created_at: datetime
179275
updated_at: datetime
180276
run_count: int
181-
last_run: Run | None
277+
last_run: RawRun | None = Field(repr=False)
182278

183279

184280
# Derived types
@@ -189,11 +285,9 @@ class TaskTree(BaseModel):
189285
children: list["TaskTree"] = []
190286

191287

192-
class SpanTree(BaseModel):
193-
"""Tree representation of a trace span with its children"""
194-
288+
class TraceTree(BaseModel):
195289
span: Task | TraceSpan
196-
children: list["SpanTree"] = []
290+
children: list["TraceTree"] = []
197291

198292

199293
# User data credentials

0 commit comments

Comments
 (0)