Skip to content

Commit e2e2b6e

Browse files
committed
refactor: type metrics_extractor as Callable[[Any], Optional[LDAIMetrics]], remove defensive getattr
1 parent 7c764b8 commit e2e2b6e

1 file changed

Lines changed: 14 additions & 6 deletions

File tree

packages/sdk/server-ai/src/ldai/tracker.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
from __future__ import annotations
2+
13
import base64
24
import json
35
import time
46
import warnings
57
from dataclasses import dataclass
68
from enum import Enum
7-
from typing import Any, Callable, Dict, Iterable, List, Optional
9+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
810

911
from ldclient import Context, LDClient, Result
1012

1113
from ldai import log
1214

15+
if TYPE_CHECKING:
16+
from ldai.providers.types import LDAIMetrics
17+
1318

1419
class FeedbackKind(Enum):
1520
"""
@@ -282,7 +287,7 @@ def track_duration_of(self, func):
282287
def _track_from_metrics_extractor(
283288
self,
284289
result: Any,
285-
metrics_extractor: Callable[[Any], Any],
290+
metrics_extractor: Callable[[Any], Optional[LDAIMetrics]],
286291
elapsed_ms: int,
287292
) -> None:
288293
metrics = None
@@ -295,8 +300,7 @@ def _track_from_metrics_extractor(
295300
self.track_duration(elapsed_ms)
296301
return
297302

298-
reported_ms = getattr(metrics, 'duration_ms', None)
299-
self.track_duration(reported_ms if reported_ms is not None else elapsed_ms)
303+
self.track_duration(metrics.duration_ms if metrics.duration_ms is not None else elapsed_ms)
300304
if metrics.success:
301305
self.track_success()
302306
else:
@@ -308,7 +312,7 @@ def _track_from_metrics_extractor(
308312

309313
def track_metrics_of(
310314
self,
311-
metrics_extractor: Callable[[Any], Any],
315+
metrics_extractor: Callable[[Any], Optional[LDAIMetrics]],
312316
func: Callable[[], Any],
313317
) -> Any:
314318
"""
@@ -344,7 +348,11 @@ def track_metrics_of(
344348
self._track_from_metrics_extractor(result, metrics_extractor, elapsed_ms)
345349
return result
346350

347-
async def track_metrics_of_async(self, metrics_extractor, func):
351+
async def track_metrics_of_async(
352+
self,
353+
metrics_extractor: Callable[[Any], Optional[LDAIMetrics]],
354+
func: Callable[[], Any],
355+
) -> Any:
348356
"""
349357
Track metrics for an async AI operation (``func`` is awaited).
350358

0 commit comments

Comments
 (0)