Skip to content

Commit 62fe644

Browse files
committed
Add metric agg mode locally. Add task execution stats. Clean up dependencies.
1 parent ec260df commit 62fe644

6 files changed

Lines changed: 1197 additions & 25 deletions

File tree

dreadnode/main.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
ENV_SERVER,
3333
ENV_SERVER_URL,
3434
)
35-
from dreadnode.metric import Metric, Scorer, ScorerCallable, T
35+
from dreadnode.metric import Metric, MetricMode, Scorer, ScorerCallable, T
3636
from dreadnode.task import P, R, Task
3737
from dreadnode.tracing.exporters import (
3838
FileExportConfig,
@@ -757,6 +757,7 @@ def log_metric(
757757
step: int = 0,
758758
origin: t.Any | None = None,
759759
timestamp: datetime | None = None,
760+
mode: MetricMode = "direct",
760761
to: ToObject = "task-or-run",
761762
) -> None:
762763
"""
@@ -778,6 +779,14 @@ def log_metric(
778779
origin: The origin of the metric - can be provided any object which was logged
779780
as an input or output anywhere in the run.
780781
timestamp: The timestamp of the metric - defaults to the current time.
782+
mode: The aggregation mode to use for the metric. Helpful when you want to let
783+
the library take care of translating your raw values into better representations.
784+
- direct: do not modify the value at all (default)
785+
- min: the lowest observed value reported for this metric
786+
- max: the highest observed value reported for this metric
787+
- avg: the average of all reported values for this metric
788+
- sum: the cumulative sum of all reported values for this metric
789+
- count: increment every time this metric is logged - disregard value
781790
to: The target object to log the metric to. Can be "task-or-run" or "run".
782791
Defaults to "task-or-run". If "task-or-run", the metric will be logged
783792
to the current task or run, whichever is the nearest ancestor.
@@ -790,6 +799,7 @@ def log_metric(
790799
value: Metric,
791800
*,
792801
origin: t.Any | None = None,
802+
mode: MetricMode = "direct",
793803
to: ToObject = "task-or-run",
794804
) -> None:
795805
"""
@@ -809,6 +819,13 @@ def log_metric(
809819
value: The metric object.
810820
origin: The origin of the metric - can be provided any object which was logged
811821
as an input or output anywhere in the run.
822+
mode: The aggregation mode to use for the metric. Helpful when you want to let
823+
the library take care of translating your raw values into better representations.
824+
- direct: do not modify the value at all (default)
825+
- min: always report the lowest observed value for this metric
826+
- max: always report the highest observed value for this metric
827+
- sum: report a rolling sum of all values for this metric
828+
- count: report the number of times this metric has been logged
812829
to: The target object to log the metric to. Can be "task-or-run" or "run".
813830
Defaults to "task-or-run". If "task-or-run", the metric will be logged
814831
to the current task or run, whichever is the nearest ancestor.
@@ -824,6 +841,7 @@ def log_metric(
824841
step: int = 0,
825842
origin: t.Any | None = None,
826843
timestamp: datetime | None = None,
844+
mode: MetricMode = "direct",
827845
to: ToObject = "task-or-run",
828846
) -> None:
829847
task = current_task_span.get()
@@ -838,7 +856,7 @@ def log_metric(
838856
if isinstance(value, Metric)
839857
else Metric(float(value), step, timestamp or datetime.now(timezone.utc))
840858
)
841-
target.log_metric(key, metric, origin=origin)
859+
target.log_metric(key, metric, origin=origin, mode=mode)
842860

843861
@handle_internal_errors()
844862
def log_artifact(

dreadnode/metric.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
T = t.TypeVar("T")
1212

13+
MetricMode = t.Literal["direct", "avg", "sum", "min", "max", "count"]
14+
1315

1416
@dataclass
1517
class Metric:
@@ -55,6 +57,46 @@ def from_many(
5557
score_attributes = {name: value for name, value, _ in values}
5658
return cls(value=total / weight, step=step, attributes={**attributes, **score_attributes})
5759

60+
def apply_mode(self, mode: "MetricMode", others: "list[Metric]") -> "Metric":
    """
    Apply an aggregation mode to the metric.
    This will modify the metric in place.

    For every mode other than "direct", the raw value is preserved in
    `attributes["original"]` and the mode is recorded in `attributes["mode"]`
    before `value` is replaced with the aggregated result.

    Args:
        mode: The mode to apply. One of "direct", "avg", "sum", "min", "max", or "count".
        others: The metrics already logged that this one is aggregated against.
            May be empty, in which case this metric is the first observation.

    Returns:
        self

    Raises:
        ValueError: If `mode` differs from the mode recorded on the first of `others`
            (metrics without a recorded mode are treated as "direct").
    """
    # The first prior metric decides the mode for the series; with no priors
    # the requested mode is accepted as-is.
    previous_mode = next((m.attributes.get("mode") for m in others), mode) or "direct"
    if mode != previous_mode:
        raise ValueError(
            f"Cannot mix metric modes {mode} != {previous_mode}",
        )

    if mode == "direct":
        return self

    self.attributes["original"] = self.value
    self.attributes["mode"] = mode

    # Prior values are already aggregated; order them oldest -> newest so the
    # last entry is the most recent running aggregate.
    prior_values = [m.value for m in sorted(others, key=lambda m: m.timestamp)]

    if mode == "sum":
        # Add onto the latest running total. Using max() here would crash on
        # the first observation (empty list) and pick a stale total whenever
        # a negative value has lowered the sum.
        if prior_values:
            self.value += prior_values[-1]
    elif mode == "min":
        self.value = min([self.value, *prior_values])
    elif mode == "max":
        self.value = max([self.value, *prior_values])
    elif mode == "count":
        self.value = len(others) + 1
    elif mode == "avg" and prior_values:
        # Incremental mean: the latest prior value is the average so far.
        current_avg = prior_values[-1]
        self.value = current_avg + (self.value - current_avg) / (len(prior_values) + 1)

    return self
99+
58100

59101
MetricDict = dict[str, list[Metric]]
60102

dreadnode/task.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def top_n(
5252
*,
5353
as_outputs: t.Literal[False] = False,
5454
reverse: bool = True,
55-
) -> "TaskSpanList[R]": ...
55+
) -> "TaskSpanList[R]":
56+
...
5657

5758
@t.overload
5859
def top_n(
@@ -61,7 +62,8 @@ def top_n(
6162
*,
6263
as_outputs: t.Literal[True],
6364
reverse: bool = True,
64-
) -> list[R]: ...
65+
) -> list[R]:
66+
...
6567

6668
def top_n(
6769
self,
@@ -83,7 +85,7 @@ def top_n(
8385
"""
8486
sorted_ = self.sorted(reverse=reverse)[:n]
8587
return (
86-
t.cast(list[R], [span.output for span in sorted_]) # noqa: TC006
88+
t.cast(list[R], [span.output for span in sorted_])
8789
if as_outputs
8890
else TaskSpanList(sorted_)
8991
)
@@ -246,6 +248,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
246248
run_id=run.run_id,
247249
tracer=self.tracer,
248250
) as span:
251+
span.run.log_metric(f"{self.label}.exec.count", 1, mode="count")
252+
249253
for name, value in params_to_log.items():
250254
span.log_param(name, value)
251255

@@ -254,10 +258,15 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
254258
for name, value in inputs_to_log.items()
255259
]
256260

257-
output = t.cast(R | t.Awaitable[R], self.func(*args, **kwargs)) # noqa: TC006
258-
if inspect.isawaitable(output):
259-
output = await output
261+
try:
262+
output = t.cast(R | t.Awaitable[R], self.func(*args, **kwargs))
263+
if inspect.isawaitable(output):
264+
output = await output
265+
except Exception:
266+
span.run.log_metric(f"{self.label}.exec.success_rate", 0, mode="avg")
267+
raise
260268

269+
span.run.log_metric(f"{self.label}.exec.success_rate", 1, mode="avg")
261270
span.output = output
262271

263272
if self.log_output:

dreadnode/tracing/span.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from dreadnode.artifact.storage import ArtifactStorage
2929
from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode
3030
from dreadnode.constants import MAX_INLINE_OBJECT_BYTES
31-
from dreadnode.metric import Metric, MetricDict
31+
from dreadnode.metric import Metric, MetricDict, MetricMode
3232
from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal
3333
from dreadnode.serialization import Serialized, serialize
3434
from dreadnode.types import UNSET, AnyDict, JsonDict, JsonValue, Unset
@@ -526,7 +526,9 @@ def log_metric(
526526
step: int = 0,
527527
origin: t.Any | None = None,
528528
timestamp: datetime | None = None,
529-
) -> None: ...
529+
mode: MetricMode = "direct",
530+
) -> None:
531+
...
530532

531533
@t.overload
532534
def log_metric(
@@ -535,7 +537,9 @@ def log_metric(
535537
value: Metric,
536538
*,
537539
origin: t.Any | None = None,
538-
) -> None: ...
540+
mode: MetricMode = "direct",
541+
) -> None:
542+
...
539543

540544
def log_metric(
541545
self,
@@ -545,6 +549,7 @@ def log_metric(
545549
step: int = 0,
546550
origin: t.Any | None = None,
547551
timestamp: datetime | None = None,
552+
mode: MetricMode = "direct",
548553
) -> None:
549554
metric = (
550555
value
@@ -560,9 +565,8 @@ def log_metric(
560565
)
561566
metric.attributes[METRIC_ATTRIBUTE_SOURCE_HASH] = origin_hash
562567

563-
self._metrics.setdefault(key, []).append(metric)
564-
if self._span is None:
565-
return
568+
metrics = self._metrics.setdefault(key, [])
569+
metrics.append(metric.apply_mode(mode, metrics))
566570

567571
@property
568572
def outputs(self) -> AnyDict:
@@ -735,7 +739,9 @@ def log_metric(
735739
step: int = 0,
736740
origin: t.Any | None = None,
737741
timestamp: datetime | None = None,
738-
) -> None: ...
742+
mode: MetricMode = "direct",
743+
) -> None:
744+
...
739745

740746
@t.overload
741747
def log_metric(
@@ -744,7 +750,9 @@ def log_metric(
744750
value: Metric,
745751
*,
746752
origin: t.Any | None = None,
747-
) -> None: ...
753+
mode: MetricMode = "direct",
754+
) -> None:
755+
...
748756

749757
def log_metric(
750758
self,
@@ -754,6 +762,7 @@ def log_metric(
754762
step: int = 0,
755763
origin: t.Any | None = None,
756764
timestamp: datetime | None = None,
765+
mode: MetricMode = "direct",
757766
) -> None:
758767
metric = (
759768
value
@@ -769,12 +778,13 @@ def log_metric(
769778
)
770779
metric.attributes[METRIC_ATTRIBUTE_SOURCE_HASH] = origin_hash
771780

772-
self._metrics.setdefault(key, []).append(metric)
781+
metrics = self._metrics.setdefault(key, [])
782+
metrics.append(metric.apply_mode(mode, metrics))
773783

774784
# For every metric we log, also log it to the run
775785
# with our `label` as a prefix.
776786
#
777-
# Don't include `source` as we handled it here.
787+
# Don't include `source` and `mode` as we handled them here.
778788
if (run := current_run_span.get()) is not None:
779789
run.log_metric(f"{self._label}.{key}", metric)
780790

pyproject.toml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,13 @@ readme = "README.md"
4646
python = ">=3.10,<3.13"
4747
pydantic = "^2.9.2"
4848
httpx = "^0.28.0"
49-
ruamel-yaml = "^0.18.6"
5049
logfire = "^3.5.3"
5150
python-ulid = "^3.0.0"
5251
fast-depends = "^2.4.12"
5352
coolname = "^2.2.0"
54-
pandas = "^2.2.3"
55-
pyarrow = "^19.0.1"
56-
loguru = "^0.7.3"
5753
fsspec = { extras = [
5854
"s3",
5955
], version = "2024.12.0" } # pinned this version to be compatible with datasets
60-
pydub = "^0.25.1"
61-
moviepy = "^2.1.2"
62-
datasets = "^3.5.0"
6356

6457
[tool.poetry.group.dev.dependencies]
6558
mypy = "^1.8.0"
@@ -72,6 +65,11 @@ pandas-stubs = "^2.2.3.250308"
7265
types-requests = "^2.32.0.20250306"
7366
rigging = "^2.3.0"
7467
typer = "^0.15.2"
68+
pydub = "^0.25.1"
69+
moviepy = "^2.1.2"
70+
datasets = "^3.5.0"
71+
pandas = "^2.2.3"
72+
pyarrow = "^19.0.1"
7573

7674
[tool.pytest.ini_options]
7775
asyncio_mode = "auto"

0 commit comments

Comments
 (0)