Skip to content

Commit 47c0e2d

Browse files
authored
[tensorflow]: Add a few missing elements (#15265)
1 parent 7d70675 commit 47c0e2d

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

stubs/tensorflow/tensorflow/__init__.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ from tensorflow import (
1818
io as io,
1919
keras as keras,
2020
math as math,
21+
nn as nn,
2122
random as random,
2223
types as types,
2324
)
@@ -37,7 +38,7 @@ from tensorflow.core.protobuf import struct_pb2
3738
from tensorflow.dtypes import *
3839
from tensorflow.experimental.dtensor import Layout
3940
from tensorflow.keras import losses as losses
40-
from tensorflow.linalg import eye as eye
41+
from tensorflow.linalg import eye as eye, matmul as matmul
4142

4243
# Most tf.math functions are exported as tf, but sadly not all are.
4344
from tensorflow.math import (
@@ -441,4 +442,10 @@ def gather_nd(
441442
name: str | None = None,
442443
bad_indices_policy: Literal["", "DEFAULT", "ERROR", "IGNORE"] = "",
443444
) -> Tensor: ...
445+
def transpose(
446+
a: Tensor, perm: Sequence[int] | IntArray | None = None, conjugate: _bool = False, name: str = "transpose"
447+
) -> Tensor: ...
448+
def clip_by_value(
449+
t: Tensor | IndexedSlices, clip_value_min: TensorCompatible, clip_value_max: TensorCompatible, name: str | None = None
450+
) -> Tensor: ...
444451
def __getattr__(name: str): ... # incomplete module

stubs/tensorflow/tensorflow/keras/metrics.pyi

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from _typeshed import Incomplete
12
from abc import ABCMeta, abstractmethod
23
from collections.abc import Callable, Iterable, Sequence
3-
from typing import Any, Literal
4+
from enum import Enum
5+
from typing import Any, Literal, type_check_only
46
from typing_extensions import Self, TypeAlias
57

68
import tensorflow as tf
@@ -107,6 +109,25 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
107109
self, k: int = 5, name: str | None = "sparse_top_k_categorical_accuracy", dtype: DTypeLike | None = None
108110
) -> None: ...
109111

112+
# TODO: Actually tensorflow.python.keras.utils.metrics_utils.Reduction, but that module
113+
# is currently missing from the stub.
114+
@type_check_only
115+
class _Reduction(Enum):
116+
SUM = "sum"
117+
SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
118+
WEIGHTED_MEAN = "weighted_mean"
119+
120+
class Reduce(Metric):
121+
reduction: _Reduction
122+
total: Incomplete
123+
count: Incomplete # only defined for some reductions
124+
def __init__(self, reduction: _Reduction, name: str | None, dtype: DTypeLike | None = None) -> None: ...
125+
def update_state(self, values, sample_weight=None): ... # type: ignore[override]
126+
def result(self) -> Tensor: ...
127+
128+
class Mean(Reduce):
129+
def __init__(self, name: str | None = "mean", dtype: DTypeLike | None = None) -> None: ...
130+
110131
def serialize(metric: KerasSerializable) -> dict[str, Any]: ...
111132
def binary_crossentropy(
112133
y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1

0 commit comments

Comments
 (0)