Skip to content

Commit 15fa3cf

Browse files
authored
tensorflow: Add members from tensorflow.keras.metrics (#11329)
Partially taken from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/keras/metrics.pyi
1 parent 69354d7 commit 15fa3cf

File tree

1 file changed

+113
-2
lines changed

1 file changed

+113
-2
lines changed
Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,120 @@
1-
from tensorflow import Tensor
2-
from tensorflow._aliases import TensorCompatible
1+
from _typeshed import Incomplete
2+
from abc import ABCMeta, abstractmethod
3+
from collections.abc import Callable, Iterable, Sequence
4+
from typing import Any, Literal
5+
from typing_extensions import Self, TypeAlias, override
36

7+
import tensorflow as tf
8+
from tensorflow import Operation, Tensor
9+
from tensorflow._aliases import DTypeLike, KerasSerializable, TensorCompatible
10+
from tensorflow.keras.initializers import _Initializer
11+
12+
_Output: TypeAlias = Tensor | dict[str, Tensor]
13+
14+
class Metric(tf.keras.layers.Layer[tf.Tensor, tf.Tensor], metaclass=ABCMeta):
15+
def __init__(self, name: str | None = None, dtype: DTypeLike | None = None) -> None: ...
16+
def __new__(cls, *args: Any, **kwargs: Any) -> Self: ...
17+
def merge_state(self, metrics: Iterable[Self]) -> list[Operation]: ...
18+
def reset_state(self) -> None: ...
19+
@abstractmethod
20+
def update_state(
21+
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
22+
) -> Operation | None: ...
23+
@abstractmethod
24+
def result(self) -> _Output: ...
25+
# Metric inherits from keras.Layer, but its add_weight method is incompatible with the one from "Layer".
26+
@override
27+
def add_weight( # type: ignore
28+
self,
29+
name: str,
30+
shape: Iterable[int | None] | None = (),
31+
aggregation: tf.VariableAggregation = ...,
32+
synchronization: tf.VariableSynchronization = ...,
33+
initializer: _Initializer | None = None,
34+
dtype: DTypeLike | None = None,
35+
) -> None: ...
36+
37+
class AUC(Metric):
38+
_from_logits: bool
39+
_num_labels: int
40+
num_labels: int | None
41+
def __init__(
42+
self,
43+
num_thresholds: int = 200,
44+
curve: Literal["ROC", "PR"] = "ROC",
45+
summation_method: Literal["interpolation", "minoring", "majoring"] = "interpolation",
46+
name: str | None = None,
47+
dtype: DTypeLike | None = None,
48+
thresholds: Sequence[float] | None = None,
49+
multi_label: bool = False,
50+
num_labels: int | None = None,
51+
label_weights: TensorCompatible | None = None,
52+
from_logits: bool = False,
53+
) -> None: ...
54+
def update_state(
55+
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
56+
) -> Operation: ...
57+
def result(self) -> tf.Tensor: ...
58+
59+
class Precision(Metric):
60+
def __init__(
61+
self,
62+
thresholds: float | Sequence[float] | None = None,
63+
top_k: int | None = None,
64+
class_id: int | None = None,
65+
name: str | None = None,
66+
dtype: DTypeLike | None = None,
67+
) -> None: ...
68+
def update_state(
69+
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
70+
) -> Operation: ...
71+
def result(self) -> tf.Tensor: ...
72+
73+
class Recall(Metric):
74+
def __init__(
75+
self,
76+
thresholds: float | Sequence[float] | None = None,
77+
top_k: int | None = None,
78+
class_id: int | None = None,
79+
name: str | None = None,
80+
dtype: DTypeLike | None = None,
81+
) -> None: ...
82+
def update_state(
83+
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
84+
) -> Operation: ...
85+
def result(self) -> tf.Tensor: ...
86+
87+
class MeanMetricWrapper(Metric):
88+
def __init__(
89+
self, fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor], name: str | None = None, dtype: DTypeLike | None = None
90+
) -> None: ...
91+
def update_state(
92+
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
93+
) -> Operation: ...
94+
def result(self) -> tf.Tensor: ...
95+
96+
class BinaryAccuracy(MeanMetricWrapper):
97+
def __init__(self, name: str | None = "binary_accuracy", dtype: DTypeLike | None = None, threshold: float = 0.5) -> None: ...
98+
99+
class Accuracy(MeanMetricWrapper):
100+
def __init__(self, name: str | None = "accuracy", dtype: DTypeLike | None = None) -> None: ...
101+
102+
class CategoricalAccuracy(MeanMetricWrapper):
103+
def __init__(self, name: str | None = "categorical_accuracy", dtype: DTypeLike | None = None) -> None: ...
104+
105+
class TopKCategoricalAccuracy(MeanMetricWrapper):
106+
def __init__(self, k: int = 5, name: str | None = "top_k_categorical_accuracy", dtype: DTypeLike | None = None) -> None: ...
107+
108+
class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
109+
def __init__(
110+
self, k: int = 5, name: str | None = "sparse_top_k_categorical_accuracy", dtype: DTypeLike | None = None
111+
) -> None: ...
112+
113+
def serialize(metric: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
4114
def binary_crossentropy(
5115
y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
6116
) -> Tensor: ...
7117
def categorical_crossentropy(
8118
y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
9119
) -> Tensor: ...
120+
def __getattr__(name: str) -> Incomplete: ...

0 commit comments

Comments
 (0)