|
| 1 | +from _typeshed import Incomplete |
1 | 2 | from abc import ABCMeta, abstractmethod |
2 | 3 | from collections.abc import Callable, Iterable, Sequence |
3 | | -from typing import Any, Literal |
| 4 | +from enum import Enum |
| 5 | +from typing import Any, Literal, type_check_only |
4 | 6 | from typing_extensions import Self, TypeAlias |
5 | 7 |
|
6 | 8 | import tensorflow as tf |
@@ -107,6 +109,25 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper): |
107 | 109 | self, k: int = 5, name: str | None = "sparse_top_k_categorical_accuracy", dtype: DTypeLike | None = None |
108 | 110 | ) -> None: ... |
109 | 111 |
|
| 112 | +# TODO: Actually tensorflow.python.keras.utils.metrics_utils.Reduction, but that module |
| 113 | +# is currently missing from the stub. |
| 114 | +@type_check_only |
| 115 | +class _Reduction(Enum): |
| 116 | + SUM = "sum" |
| 117 | + SUM_OVER_BATCH_SIZE = "sum_over_batch_size" |
| 118 | + WEIGHTED_MEAN = "weighted_mean" |
| 119 | + |
| 120 | +class Reduce(Metric): |
| 121 | + reduction: _Reduction |
| 122 | + total: Incomplete |
| 123 | + count: Incomplete # only defined for some reductions |
| 124 | + def __init__(self, reduction: _Reduction, name: str | None, dtype: DTypeLike | None = None) -> None: ... |
| 125 | + def update_state(self, values, sample_weight=None): ... # type: ignore[override] |
| 126 | + def result(self) -> Tensor: ... |
| 127 | + |
| 128 | +class Mean(Reduce): |
| 129 | + def __init__(self, name: str | None = "mean", dtype: DTypeLike | None = None) -> None: ... |
| 130 | + |
110 | 131 | def serialize(metric: KerasSerializable) -> dict[str, Any]: ... |
111 | 132 | def binary_crossentropy( |
112 | 133 | y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1 |
|
0 commit comments