Skip to content

Commit e99c704

Browse files
committed
Fixes for task/run input and output handling. Add artifacts for runs
1 parent 72a72fc commit e99c704

5 files changed

Lines changed: 541 additions & 296 deletions

File tree

dreadnode/api/client.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
from pydantic import BaseModel
88
from ulid import ULID
99

10+
from dreadnode.api.util import process_run, process_task
1011
from dreadnode.util import logger
1112
from dreadnode.version import VERSION
1213

1314
from .models import (
1415
MetricAggregationType,
1516
Project,
17+
RawRun,
18+
RawTask,
1619
Run,
20+
RunSummary,
1721
StatusFilter,
1822
Task,
1923
TimeAggregationType,
@@ -119,27 +123,33 @@ def get_project(self, project: str) -> Project:
119123
response = self.request("GET", f"/strikes/projects/{project!s}")
120124
return Project(**response.json())
121125

# list_runs: GETs /strikes/projects/{project}/runs through self.request and
# builds one model per element of the JSON response.
# This hunk changes both the return annotation and the constructed model
# from Run to RunSummary; the request itself is unchanged.
122-
def list_runs(self, project: str) -> list[Run]:
126+
def list_runs(self, project: str) -> list[RunSummary]:
123127
response = self.request("GET", f"/strikes/projects/{project!s}/runs")
124-
return [Run(**run) for run in response.json()]
128+
return [RunSummary(**run) for run in response.json()]
125129

# get_run: this hunk splits the old method in two.
#   _get_run - same GET /strikes/projects/runs/{run} request as before, but
#              the JSON is now validated as RawRun instead of Run.
#   get_run  - keeps the public name and the `-> Run` annotation; it returns
#              process_run(self._get_run(run)).
# process_run is imported from dreadnode.api.util; its behaviour is not shown
# in this diff.
126-
def get_run(self, run: str | ULID) -> Run:
130+
def _get_run(self, run: str | ULID) -> RawRun:
127131
response = self.request("GET", f"/strikes/projects/runs/{run!s}")
128-
return Run(**response.json())
132+
return RawRun(**response.json())
133+
134+
def get_run(self, run: str | ULID) -> Run:
135+
return process_run(self._get_run(run))
129136

# get_run_tasks: returns the tasks of a run as Task models.
# New behaviour in this hunk: the raw run is fetched first via _get_run, each
# element of the /tasks/full response is validated as RawTask, and every
# RawTask is converted with process_task(task, raw_run).
# NOTE(review): _get_run issues its own GET, so each call now makes two
# requests (run + tasks).
130137
def get_run_tasks(self, run: str | ULID) -> list[Task]:
138+
raw_run = self._get_run(run)
131139
response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks/full")
132-
return [Task(**task) for task in response.json()]
140+
raw_tasks = [RawTask(**task) for task in response.json()]
141+
return [process_task(task, raw_run) for task in raw_tasks]
133142

# get_run_trace: returns the items of /spans/full as a mixed list of Task and
# TraceSpan models, in response order.
# An item that contains the key "parent_task_span_id" is treated as a task;
# anything else is validated as a TraceSpan.
# Changes in this hunk: the raw run is fetched up front, task items go through
# RawTask + process_task(..., raw_run) instead of Task(**item), and the local
# accumulator is renamed from `spans` to `trace`.
134143
def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]:
144+
raw_run = self._get_run(run)
135145
response = self.request("GET", f"/strikes/projects/runs/{run!s}/spans/full")
136-
spans: list[Task | TraceSpan] = []
146+
trace: list[Task | TraceSpan] = []
137147
for item in response.json():
138148
if "parent_task_span_id" in item:
139-
spans.append(Task(**item))
149+
trace.append(process_task(RawTask(**item), raw_run))
140150
else:
141-
spans.append(TraceSpan(**item))
142-
return spans
151+
trace.append(TraceSpan(**item))
152+
return trace
143153

144154
# Data exports
145155

dreadnode/api/models.py

Lines changed: 133 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
1+
import contextlib
12
import typing as t
23
from datetime import datetime
4+
from functools import cached_property
35
from uuid import UUID
46

5-
from pydantic import BaseModel, Field
7+
import requests
8+
from pydantic import (
9+
BaseModel,
10+
ConfigDict,
11+
Field,
12+
PrivateAttr,
13+
TypeAdapter,
14+
ValidationError,
15+
field_validator,
16+
)
617
from ulid import ULID
718

819
AnyDict = dict[str, t.Any]
@@ -79,17 +90,17 @@ class TraceLog(BaseModel):
# TraceSpan: pydantic model for one span of a trace.
# This hunk changes no field names or types. It only attaches
# Field(repr=False) to trace_id, parent_span_id, service_name, attributes,
# resource_attributes, events and links, which (per pydantic's Field docs)
# leaves them out of the model's repr. No default is passed to Field, so
# the fields remain required.
7990
class TraceSpan(BaseModel):
8091
timestamp: datetime
8192
duration: int
82-
trace_id: str
93+
trace_id: str = Field(repr=False)
8394
span_id: str
84-
parent_span_id: str | None
85-
service_name: str | None
95+
parent_span_id: str | None = Field(repr=False)
96+
service_name: str | None = Field(repr=False)
8697
status: SpanStatus
8798
exception: SpanException | None
8899
name: str
89-
attributes: AnyDict
90-
resource_attributes: AnyDict
91-
events: list[SpanEvent]
92-
links: list[SpanLink]
100+
attributes: AnyDict = Field(repr=False)
101+
resource_attributes: AnyDict = Field(repr=False)
102+
events: list[SpanEvent] = Field(repr=False)
103+
links: list[SpanLink] = Field(repr=False)
93104

94105

95106
class Metric(BaseModel):
@@ -105,22 +116,22 @@ class ObjectRef(BaseModel):
105116
hash: str
106117

107118

# RawObjectUri: the class previously named ObjectUri is renamed; its fields
# (hash, schema_hash, uri, size and the literal tag type == "uri") are
# unchanged context lines.
108-
class ObjectUri(BaseModel):
119+
class RawObjectUri(BaseModel):
109120
hash: str
110121
schema_hash: str
111122
uri: str
112123
size: int
113124
type: t.Literal["uri"]
114125

115126

# RawObjectVal: the class previously named ObjectVal is renamed; its fields
# (hash, schema_hash, value and the literal tag type == "val") are unchanged
# context lines.
116-
class ObjectVal(BaseModel):
127+
class RawObjectVal(BaseModel):
117128
hash: str
118129
schema_hash: str
119130
value: t.Any
120131
type: t.Literal["val"]
121132

122133

# RawObject: the union alias follows the class renames above it - the old
# `Object` alias over ObjectUri/ObjectVal becomes `RawObject` over
# RawObjectUri/RawObjectVal.
123-
Object = ObjectUri | ObjectVal
134+
RawObject = RawObjectUri | RawObjectVal
124135

125136

126137
class V0Object(BaseModel):
@@ -129,56 +140,141 @@ class V0Object(BaseModel):
129140
value: t.Any
130141

131142

132-
class Run(BaseModel):
# ObjectVal (new in this commit): an object whose value is carried inline,
# together with its name, label, hashes and schema dict.
# hash and schema_hash are hidden from repr via Field(repr=False).
# Unlike the schema_ fields on the run/task models in this file, schema_ here
# declares no alias.
#
# validate_value: field validator for `value`. When the incoming value is a
# str it is parsed as JSON with TypeAdapter(t.Any).validate_json and the
# parsed result is returned; a ValidationError is suppressed and the original
# string is returned instead. Non-str values pass through untouched.
# NOTE(review): any string that is itself valid JSON (e.g. "123" or "true")
# comes back as the parsed value, not as the string - confirm that is the
# intended contract.
143+
class ObjectVal(BaseModel):
144+
model_config = ConfigDict(arbitrary_types_allowed=True)
145+
146+
name: str
147+
label: str
148+
hash: str = Field(repr=False)
149+
schema_: AnyDict
150+
schema_hash: str = Field(repr=False)
151+
value: t.Any
152+
153+
@field_validator("value")
154+
@classmethod
155+
def validate_value(cls, value: t.Any) -> t.Any:
156+
if isinstance(value, str):
157+
with contextlib.suppress(ValidationError):
158+
return TypeAdapter(t.Any).validate_json(value)
159+
160+
return value
161+
162+
# ObjectUri (new in this commit): an object whose value lives behind a URI.
# It carries the same descriptive fields as the inline variant, plus uri and
# size, and a private _value slot (default None) used by `value`.
#
# value (cached_property):
#   1. If _value is already non-None, return it without any request.
#   2. Otherwise GET self.uri with a 5 second timeout and call
#      raise_for_status(); any requests.RequestException is re-raised as
#      RuntimeError("Failed to fetch object from <uri>") chained to the cause.
#   3. The response text is then parsed as JSON when possible; a
#      ValidationError is suppressed, leaving the raw text as the value.
# NOTE(review): on the fetch path _value was just assigned response.text, so
# the isinstance(..., str) check looks always true there.
# NOTE(review): cached_property already memoises the result, so the _value
# guard appears to matter only if _value is populated from outside - confirm.
163+
class ObjectUri(BaseModel):
164+
name: str
165+
label: str
166+
hash: str = Field(repr=False)
167+
schema_: AnyDict
168+
schema_hash: str = Field(repr=False)
169+
uri: str
170+
size: int
171+
172+
_value: t.Any = PrivateAttr(default=None)
173+
174+
@cached_property
175+
def value(self) -> t.Any:
176+
if self._value is not None:
177+
return self._value
178+
179+
try:
180+
response = requests.get(self.uri, timeout=5)
181+
response.raise_for_status()
182+
self._value = response.text
183+
except requests.RequestException as e:
184+
raise RuntimeError(f"Failed to fetch object from {self.uri}") from e
185+
186+
if isinstance(self._value, str):
187+
with contextlib.suppress(ValidationError):
188+
self._value = TypeAdapter(t.Any).validate_json(self._value)
189+
190+
return self._value
191+
192+
# Object: union alias over the two new resolved object models; it reuses the
# name that the removed raw-object alias had before this commit.
193+
Object = ObjectVal | ObjectUri
194+
195+
# ArtifactFile (new in this commit): leaf entry of an artifact tree, described
# by hash, uri, size_bytes and final_real_path. All four fields are required.
196+
class ArtifactFile(BaseModel):
197+
hash: str
198+
uri: str
199+
size_bytes: int
200+
final_real_path: str
201+
202+
# ArtifactDir (new in this commit): directory entry of an artifact tree.
# `children` is recursive: each child is either another ArtifactDir (written
# as a string forward reference because the class is still being defined) or
# an ArtifactFile.
203+
class ArtifactDir(BaseModel):
204+
dir_path: str
205+
hash: str
206+
children: list[t.Union["ArtifactDir", ArtifactFile]]
207+
208+
# RunSummary: the body of the old Run model is re-headed as RunSummary (the
# removed `class Run(BaseModel):` line appears earlier in this diff, before
# the newly added object and artifact classes).
# Kept: id, name, span_id, trace_id, timestamp, duration, status, exception,
# tags, params, metrics - with span_id, trace_id, params and metrics now
# hidden from repr.
# Removed from this class: inputs, outputs, objects, object_schemas, schema_.
209+
class RunSummary(BaseModel):
133210
id: ULID
134211
name: str
135-
span_id: str
136-
trace_id: str
212+
span_id: str = Field(repr=False)
213+
trace_id: str = Field(repr=False)
137214
timestamp: datetime
138215
duration: int
139216
status: SpanStatus
140217
exception: SpanException | None
141218
tags: set[str]
142-
params: AnyDict
143-
metrics: dict[str, list[Metric]]
144-
inputs: list[ObjectRef]
145-
outputs: list[ObjectRef]
146-
objects: dict[str, Object]
147-
object_schemas: AnyDict
148-
schema_: AnyDict = Field(alias="schema")
219+
params: AnyDict = Field(repr=False)
220+
metrics: dict[str, list[Metric]] = Field(repr=False)
149221

150222

151-
class Task(BaseModel):
# RawRun (new in this commit): RunSummary plus the unresolved payload fields -
# inputs/outputs as lists of ObjectRef, the objects table keyed by str,
# object_schemas, the artifact trees, and schema_ (populated from the "schema"
# key via alias). Every added field is hidden from repr.
# This is the model the client's _get_run validates the run response into.
223+
class RawRun(RunSummary):
224+
inputs: list[ObjectRef] = Field(repr=False)
225+
outputs: list[ObjectRef] = Field(repr=False)
226+
objects: dict[str, RawObject] = Field(repr=False)
227+
object_schemas: AnyDict = Field(repr=False)
228+
artifacts: list[ArtifactDir] = Field(repr=False)
229+
schema_: AnyDict = Field(alias="schema", repr=False)
230+
231+
# Run (re-introduced in this commit): RunSummary plus resolved inputs/outputs,
# each a dict from str to Object, together with artifacts and schema_
# (alias "schema"). It no longer has objects or object_schemas fields.
232+
class Run(RunSummary):
233+
inputs: dict[str, Object] = Field(repr=False)
234+
outputs: dict[str, Object] = Field(repr=False)
235+
artifacts: list[ArtifactDir] = Field(repr=False)
236+
schema_: AnyDict = Field(alias="schema", repr=False)
237+
238+
# _Task: shared base for the raw and resolved task models. It takes over the
# body of the old Task model (whose removed `class Task(BaseModel):` header
# appears earlier in this diff).
# Changes: inputs/outputs are dropped from the base, and every field except
# name, span_id, timestamp, duration, status, exception and tags gains
# Field(repr=False). schema_ keeps its "schema" alias.
239+
class _Task(BaseModel):
152240
name: str
153241
span_id: str
154-
trace_id: str
155-
parent_span_id: str | None
156-
parent_task_span_id: str | None
242+
trace_id: str = Field(repr=False)
243+
parent_span_id: str | None = Field(repr=False)
244+
parent_task_span_id: str | None = Field(repr=False)
157245
timestamp: datetime
158246
duration: int
159247
status: SpanStatus
160248
exception: SpanException | None
161249
tags: set[str]
162-
params: AnyDict
163-
metrics: dict[str, list[Metric]]
164-
inputs: list[ObjectRef] | list[V0Object] # v0 compat
165-
outputs: list[ObjectRef] | list[V0Object] # v0 compat
166-
schema_: AnyDict = Field(alias="schema")
167-
attributes: AnyDict
168-
resource_attributes: AnyDict
169-
events: list[SpanEvent]
170-
links: list[SpanLink]
250+
params: AnyDict = Field(repr=False)
251+
metrics: dict[str, list[Metric]] = Field(repr=False)
252+
schema_: AnyDict = Field(alias="schema", repr=False)
253+
attributes: AnyDict = Field(repr=False)
254+
resource_attributes: AnyDict = Field(repr=False)
255+
events: list[SpanEvent] = Field(repr=False)
256+
links: list[SpanLink] = Field(repr=False)
257+
258+
# RawTask (new in this commit): _Task plus unresolved inputs/outputs, each
# either a list of ObjectRef or a list of V0Object - the same union the old
# Task model annotated as "v0 compat".
259+
class RawTask(_Task):
260+
inputs: list[ObjectRef] | list[V0Object] = Field(repr=False)
261+
outputs: list[ObjectRef] | list[V0Object] = Field(repr=False)
262+
263+
# Task (re-introduced in this commit): _Task plus resolved inputs/outputs,
# each a dict from str to Object, hidden from repr.
264+
class Task(_Task):
265+
inputs: dict[str, Object] = Field(repr=False)
266+
outputs: dict[str, Object] = Field(repr=False)
171267

172268

# Project: project record returned by the API.
# Changes in this hunk: id and description are hidden from repr, and last_run
# changes type from Run | None to RawRun | None (also hidden from repr).
# NOTE(review): RawRun declares inputs, outputs, objects, object_schemas,
# artifacts and schema without defaults, so this presumably relies on the
# project endpoint embedding the full raw run payload - verify against the API.
173269
class Project(BaseModel):
174-
id: UUID
270+
id: UUID = Field(repr=False)
175271
key: str
176272
name: str
177-
description: str | None
273+
description: str | None = Field(repr=False)
178274
created_at: datetime
179275
updated_at: datetime
180276
run_count: int
181-
last_run: Run | None
277+
last_run: RawRun | None = Field(repr=False)
182278

183279

184280
# Derived types

0 commit comments

Comments
 (0)