Skip to content

Commit cff60c9

Browse files
committed
hotfix: Metric modes
1 parent ed7e687 commit cff60c9

3 files changed

Lines changed: 33 additions & 31 deletions

File tree

dreadnode/main.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
ENV_SERVER,
3333
ENV_SERVER_URL,
3434
)
35-
from dreadnode.metric import Metric, MetricMode, Scorer, ScorerCallable, T
35+
from dreadnode.metric import Metric, MetricAggMode, Scorer, ScorerCallable, T
3636
from dreadnode.task import P, R, Task
3737
from dreadnode.tracing.exporters import (
3838
FileExportConfig,
@@ -757,7 +757,7 @@ def log_metric(
757757
step: int = 0,
758758
origin: t.Any | None = None,
759759
timestamp: datetime | None = None,
760-
mode: MetricMode = "direct",
760+
mode: MetricAggMode | None = None,
761761
to: ToObject = "task-or-run",
762762
) -> None:
763763
"""
@@ -799,7 +799,7 @@ def log_metric(
799799
value: Metric,
800800
*,
801801
origin: t.Any | None = None,
802-
mode: MetricMode = "direct",
802+
mode: MetricAggMode | None = None,
803803
to: ToObject = "task-or-run",
804804
) -> None:
805805
"""
@@ -821,7 +821,6 @@ def log_metric(
821821
as an input or output anywhere in the run.
822822
mode: The aggregation mode to use for the metric. Helpful when you want to let
823823
the library take care of translating your raw values into better representations.
824-
- direct: do not modify the value at all (default)
825824
- min: always report the lowest ovbserved value for this metric
826825
- max: always report the highest observed value for this metric
827826
- avg: report the average of all values for this metric
@@ -841,7 +840,7 @@ def log_metric(
841840
step: int = 0,
842841
origin: t.Any | None = None,
843842
timestamp: datetime | None = None,
844-
mode: MetricMode = "direct",
843+
mode: MetricAggMode | None = None,
845844
to: ToObject = "task-or-run",
846845
) -> None:
847846
task = current_task_span.get()

dreadnode/metric.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
from dataclasses import dataclass, field
44
from datetime import datetime, timezone
55

6+
from logfire._internal.stack_info import warn_at_user_stacklevel
67
from logfire._internal.utils import safe_repr
78
from opentelemetry.trace import Tracer
89

910
from dreadnode.types import JsonDict, JsonValue
1011

1112
T = t.TypeVar("T")
1213

13-
MetricMode = t.Literal["direct", "avg", "sum", "min", "max", "count"]
14+
MetricAggMode = t.Literal["avg", "sum", "min", "max", "count"]
15+
16+
17+
class MetricWarning(UserWarning):
18+
pass
1419

1520

1621
@dataclass
@@ -57,7 +62,7 @@ def from_many(
5762
score_attributes = {name: value for name, value, _ in values}
5863
return cls(value=total / weight, step=step, attributes={**attributes, **score_attributes})
5964

60-
def apply_mode(self, mode: MetricMode, others: "list[Metric]") -> "Metric":
65+
def apply_mode(self, mode: MetricAggMode, others: "list[Metric]") -> "Metric":
6166
"""
6267
Apply an aggregation mode to the metric.
6368
This will modify the metric in place.
@@ -69,15 +74,13 @@ def apply_mode(self, mode: MetricMode, others: "list[Metric]") -> "Metric":
6974
Returns:
7075
self
7176
"""
72-
previous_mode = next((m.attributes.get("mode") for m in others), mode) or "direct"
73-
if mode != previous_mode:
74-
raise ValueError(
75-
f"Cannot mix metric modes {mode} != {previous_mode}",
77+
previous_mode = next((m.attributes.get("mode") for m in others), mode)
78+
if previous_mode is not None and mode != previous_mode:
79+
warn_at_user_stacklevel(
80+
f"Metric logged with different modes ({mode} != {previous_mode}). This may result in unexpected behavior.",
81+
MetricWarning,
7682
)
7783

78-
if mode == "direct":
79-
return self
80-
8184
self.attributes["original"] = self.value
8285
self.attributes["mode"] = mode
8386

dreadnode/tracing/span.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from dreadnode.artifact.storage import ArtifactStorage
2929
from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode
3030
from dreadnode.constants import MAX_INLINE_OBJECT_BYTES
31-
from dreadnode.metric import Metric, MetricDict, MetricMode
31+
from dreadnode.metric import Metric, MetricAggMode, MetricDict
3232
from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal
3333
from dreadnode.serialization import Serialized, serialize
3434
from dreadnode.types import UNSET, AnyDict, JsonDict, JsonValue, Unset
@@ -526,9 +526,8 @@ def log_metric(
526526
step: int = 0,
527527
origin: t.Any | None = None,
528528
timestamp: datetime | None = None,
529-
mode: MetricMode = "direct",
530-
) -> None:
531-
...
529+
mode: MetricAggMode | None = None,
530+
) -> None: ...
532531

533532
@t.overload
534533
def log_metric(
@@ -537,9 +536,8 @@ def log_metric(
537536
value: Metric,
538537
*,
539538
origin: t.Any | None = None,
540-
mode: MetricMode = "direct",
541-
) -> None:
542-
...
539+
mode: MetricAggMode | None = None,
540+
) -> None: ...
543541

544542
def log_metric(
545543
self,
@@ -549,7 +547,7 @@ def log_metric(
549547
step: int = 0,
550548
origin: t.Any | None = None,
551549
timestamp: datetime | None = None,
552-
mode: MetricMode = "direct",
550+
mode: MetricAggMode | None = None,
553551
) -> None:
554552
metric = (
555553
value
@@ -566,7 +564,9 @@ def log_metric(
566564
metric.attributes[METRIC_ATTRIBUTE_SOURCE_HASH] = origin_hash
567565

568566
metrics = self._metrics.setdefault(key, [])
569-
metrics.append(metric.apply_mode(mode, metrics))
567+
if mode is not None:
568+
metric = metric.apply_mode(mode, metrics)
569+
metrics.append(metric)
570570

571571
@property
572572
def outputs(self) -> AnyDict:
@@ -739,9 +739,8 @@ def log_metric(
739739
step: int = 0,
740740
origin: t.Any | None = None,
741741
timestamp: datetime | None = None,
742-
mode: MetricMode = "direct",
743-
) -> None:
744-
...
742+
mode: MetricAggMode | None = None,
743+
) -> None: ...
745744

746745
@t.overload
747746
def log_metric(
@@ -750,9 +749,8 @@ def log_metric(
750749
value: Metric,
751750
*,
752751
origin: t.Any | None = None,
753-
mode: MetricMode = "direct",
754-
) -> None:
755-
...
752+
mode: MetricAggMode | None = None,
753+
) -> None: ...
756754

757755
def log_metric(
758756
self,
@@ -762,7 +760,7 @@ def log_metric(
762760
step: int = 0,
763761
origin: t.Any | None = None,
764762
timestamp: datetime | None = None,
765-
mode: MetricMode = "direct",
763+
mode: MetricAggMode | None = None,
766764
) -> None:
767765
metric = (
768766
value
@@ -779,7 +777,9 @@ def log_metric(
779777
metric.attributes[METRIC_ATTRIBUTE_SOURCE_HASH] = origin_hash
780778

781779
metrics = self._metrics.setdefault(key, [])
782-
metrics.append(metric.apply_mode(mode, metrics))
780+
if mode is not None:
781+
metric = metric.apply_mode(mode, metrics)
782+
metrics.append(metric)
783783

784784
# For every metric we log, also log it to the run
785785
# with our `label` as a prefix.

0 commit comments

Comments
 (0)