Skip to content

Commit 314f8a0

Browse files
committed
Make some metric key handling stricter. Move attributes for objects to refs instead of events. Add run-level autolog specs and inheritance. Better handling for metric names.
1 parent d67c4c5 commit 314f8a0

5 files changed

Lines changed: 72 additions & 28 deletions

File tree

dreadnode/main.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747
current_task_span,
4848
)
4949
from dreadnode.types import (
50+
INHERITED,
5051
AnyDict,
52+
Inherited,
5153
JsonDict,
5254
JsonValue,
5355
)
@@ -412,8 +414,8 @@ def task(
412414
name: str | None = None,
413415
label: str | None = None,
414416
log_params: t.Sequence[str] | bool = False,
415-
log_inputs: t.Sequence[str] | bool = True,
416-
log_output: bool = True,
417+
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED,
418+
log_output: bool | Inherited = INHERITED,
417419
tags: t.Sequence[str] | None = None,
418420
**attributes: t.Any,
419421
) -> TaskDecorator: ...
@@ -426,8 +428,8 @@ def task(
426428
name: str | None = None,
427429
label: str | None = None,
428430
log_params: t.Sequence[str] | bool = False,
429-
log_inputs: t.Sequence[str] | bool = True,
430-
log_output: bool = True,
431+
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED,
432+
log_output: bool | Inherited = INHERITED,
431433
tags: t.Sequence[str] | None = None,
432434
**attributes: t.Any,
433435
) -> ScoredTaskDecorator[R]: ...
@@ -439,8 +441,8 @@ def task(
439441
name: str | None = None,
440442
label: str | None = None,
441443
log_params: t.Sequence[str] | bool = False,
442-
log_inputs: t.Sequence[str] | bool = True,
443-
log_output: bool = True,
444+
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED,
445+
log_output: bool | Inherited = INHERITED,
444446
tags: t.Sequence[str] | None = None,
445447
**attributes: t.Any,
446448
) -> TaskDecorator:
@@ -622,6 +624,7 @@ def run(
622624
tags: t.Sequence[str] | None = None,
623625
params: AnyDict | None = None,
624626
project: str | None = None,
627+
autolog: bool = True,
625628
**attributes: t.Any,
626629
) -> RunSpan:
627630
"""
@@ -647,6 +650,7 @@ def run(
647650
project: The project name to associate the run with. If not provided,
648651
the project passed to `configure()` will be used, or the
649652
run will be associated with a default project.
653+
autolog: Whether to automatically log task inputs, outputs, and execution metrics if unspecified.
650654
**attributes: Additional attributes to attach to the run span.
651655
"""
652656
if not self._initialized:
@@ -664,6 +668,7 @@ def run(
664668
tags=tags,
665669
file_system=self._fs,
666670
prefix_path=self._fs_prefix,
671+
autolog=autolog,
667672
)
668673

669674
@handle_internal_errors()

dreadnode/object.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import typing as t
22
from dataclasses import dataclass
33

4+
from dreadnode.types import JsonDict
5+
46

57
@dataclass
68
class ObjectRef:
79
name: str
810
label: str
911
hash: str
12+
attributes: JsonDict
1013

1114

1215
@dataclass

dreadnode/task.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from dreadnode.metric import Scorer, ScorerCallable
1111
from dreadnode.tracing.span import TaskSpan, current_run_span
12+
from dreadnode.types import INHERITED, Inherited
1213

1314
P = t.ParamSpec("P")
1415
R = t.TypeVar("R")
@@ -114,9 +115,9 @@ class Task(t.Generic[P, R]):
114115

115116
log_params: t.Sequence[str] | bool = False
116117
"Whether to log all, or specific, incoming arguments to the function as parameters."
117-
log_inputs: t.Sequence[str] | bool = True
118+
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED
118119
"Whether to log all, or specific, incoming arguments to the function as inputs."
119-
log_output: bool = True
120+
log_output: bool | Inherited = INHERITED
120121
"Whether to automatically log the result of the function as an output."
121122

122123
def __post_init__(self) -> None:
@@ -239,6 +240,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
239240
if run is None or not run.is_recording:
240241
raise RuntimeError("Tasks must be executed within a run")
241242

243+
log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs
244+
log_output = run.autolog if isinstance(self.log_output, Inherited) else self.log_output
245+
242246
bound_args = self._bind_args(*args, **kwargs)
243247

244248
params_to_log = (
@@ -250,9 +254,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
250254
)
251255
inputs_to_log = (
252256
bound_args
253-
if self.log_inputs is True
254-
else {k: v for k, v in bound_args.items() if k in self.log_inputs}
255-
if self.log_inputs is not False
257+
if log_inputs is True
258+
else {k: v for k, v in bound_args.items() if k in log_inputs}
259+
if log_inputs is not False
256260
else {}
257261
)
258262

@@ -265,13 +269,16 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
265269
run_id=run.run_id,
266270
tracer=self.tracer,
267271
) as span:
268-
span.run.log_metric(f"{self.label}.exec.count", 1, mode="count")
272+
if run.autolog:
273+
span.run.log_metric(
274+
"count", 1, prefix=f"{self.label}.exec", mode="count", attributes={"auto": True}
275+
)
269276

270277
for name, value in params_to_log.items():
271278
span.log_param(name, value)
272279

273280
input_object_hashes: list[str] = [
274-
span.log_input(name, value, label=f"{self.label}.input.{name}")
281+
span.log_input(name, value, label=f"{self.label}.input.{name}", auto=True)
275282
for name, value in inputs_to_log.items()
276283
]
277284

@@ -280,17 +287,29 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
280287
if inspect.isawaitable(output):
281288
output = await output
282289
except Exception:
283-
span.run.log_metric(f"{self.label}.exec.success_rate", 0, mode="avg")
290+
if run.autolog:
291+
span.run.log_metric(
292+
"success_rate",
293+
0,
294+
prefix=f"{self.label}.exec",
295+
mode="avg",
296+
attributes={"auto": True},
297+
)
284298
raise
285299

286-
span.run.log_metric(f"{self.label}.exec.success_rate", 1, mode="avg")
300+
if run.autolog:
301+
span.run.log_metric(
302+
"success_rate",
303+
1,
304+
prefix=f"{self.label}.exec",
305+
mode="avg",
306+
attributes={"auto": True},
307+
)
287308
span.output = output
288309

289-
if self.log_output:
310+
if log_output:
290311
output_object_hash = span.log_output(
291-
"output",
292-
output,
293-
label=f"{self.label}.output",
312+
"output", output, label=f"{self.label}.output", auto=True
294313
)
295314

296315
# Link the output to the inputs

dreadnode/tracing/span.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,15 @@ def __init__(
250250
tracer: Tracer,
251251
file_system: AbstractFileSystem,
252252
prefix_path: str,
253+
*,
253254
params: AnyDict | None = None,
254255
metrics: MetricDict | None = None,
255256
run_id: str | None = None,
256257
tags: t.Sequence[str] | None = None,
258+
autolog: bool = True,
257259
) -> None:
260+
self.autolog = autolog
261+
258262
self._params = params or {}
259263
self._metrics = metrics or {}
260264
self._objects: dict[str, Object] = {}
@@ -486,9 +490,8 @@ def log_input(
486490
value,
487491
label=label,
488492
event_name=EVENT_NAME_OBJECT_INPUT,
489-
**attributes,
490493
)
491-
self._inputs.append(ObjectRef(name, label=label, hash=hash_))
494+
self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
492495

493496
def log_artifact(
494497
self,
@@ -528,6 +531,7 @@ def log_metric(
528531
origin: t.Any | None = None,
529532
timestamp: datetime | None = None,
530533
mode: MetricAggMode | None = None,
534+
prefix: str | None = None,
531535
attributes: JsonDict | None = None,
532536
) -> None: ...
533537

@@ -539,6 +543,7 @@ def log_metric(
539543
*,
540544
origin: t.Any | None = None,
541545
mode: MetricAggMode | None = None,
546+
prefix: str | None = None,
542547
) -> None: ...
543548

544549
def log_metric(
@@ -550,6 +555,7 @@ def log_metric(
550555
origin: t.Any | None = None,
551556
timestamp: datetime | None = None,
552557
mode: MetricAggMode | None = None,
558+
prefix: str | None = None,
553559
attributes: JsonDict | None = None,
554560
) -> None:
555561
metric = (
@@ -560,6 +566,10 @@ def log_metric(
560566
)
561567
)
562568

569+
key = re.sub(r"[^\w/]+", "_", key.lower())
570+
if prefix is not None:
571+
key = f"{prefix}.{key}"
572+
563573
if origin is not None:
564574
origin_hash = self.log_object(
565575
origin,
@@ -590,9 +600,8 @@ def log_output(
590600
value,
591601
label=label,
592602
event_name=EVENT_NAME_OBJECT_OUTPUT,
593-
**attributes,
594603
)
595-
self._outputs.append(ObjectRef(name, label=label, hash=hash_))
604+
self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
596605

597606

598607
class TaskSpan(Span, t.Generic[R]):
@@ -694,9 +703,8 @@ def log_output(
694703
value,
695704
label=label,
696705
event_name=EVENT_NAME_OBJECT_OUTPUT,
697-
**attributes,
698706
)
699-
self._outputs.append(ObjectRef(name, label=label, hash=hash_))
707+
self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
700708
return hash_
701709

702710
@property
@@ -726,9 +734,8 @@ def log_input(
726734
value,
727735
label=label,
728736
event_name=EVENT_NAME_OBJECT_INPUT,
729-
**attributes,
730737
)
731-
self._inputs.append(ObjectRef(name, label=label, hash=hash_))
738+
self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
732739
return hash_
733740

734741
@property
@@ -777,6 +784,8 @@ def log_metric(
777784
)
778785
)
779786

787+
key = re.sub(r"[^\w/]+", "_", key.lower())
788+
780789
if origin is not None:
781790
origin_hash = self.run.log_object(
782791
origin,
@@ -795,7 +804,7 @@ def log_metric(
795804
#
796805
# Don't include `source` and `mode` as we handled it here.
797806
if (run := current_run_span.get()) is not None:
798-
run.log_metric(f"{self._label}.{key}", metric)
807+
run.log_metric(key, metric, prefix=self._label)
799808

800809
def get_average_metric_value(self, key: str | None = None) -> float:
801810
metrics = (

dreadnode/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,11 @@ def __bool__(self) -> t.Literal[False]:
2323

2424

2525
UNSET: Unset = Unset()
26+
27+
28+
class Inherited:
29+
def __repr__(self) -> str:
30+
return "Inherited"
31+
32+
33+
INHERITED: Inherited = Inherited()

0 commit comments

Comments
 (0)