Skip to content

Commit 6087745

Browse files
authored
tensorflow: add tf.ones, tf.zeros, tf.zeros_like and tf.ones_like functions (#11368)
1 parent d93ee88 commit 6087745

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

stubs/tensorflow/tensorflow/__init__.pyi

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ from tensorflow.core.protobuf import struct_pb2
3636
# is necessary to avoid a crash in pytype.
3737
from tensorflow.dtypes import *
3838
from tensorflow.dtypes import DType as DType
39+
from tensorflow.experimental.dtensor import Layout
3940
from tensorflow.keras import losses as losses
4041

4142
# Most tf.math functions are exported as tf, but sadly not all are.
@@ -438,5 +439,23 @@ def cast(x: TensorCompatible, dtype: DTypeLike, name: str | None = None) -> Tens
438439
def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTensor: ...
439440
@overload
440441
def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ...
442+
def zeros(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ...
443+
def ones(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ...
444+
@overload
445+
def zeros_like(
446+
input: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
447+
) -> Tensor: ...
448+
@overload
449+
def zeros_like(
450+
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
451+
) -> RaggedTensor: ...
452+
@overload
453+
def ones_like(
454+
input: TensorCompatible, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
455+
) -> Tensor: ...
456+
@overload
457+
def ones_like(
458+
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
459+
) -> RaggedTensor: ...
441460
def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ...
442461
def __getattr__(name: str) -> Incomplete: ...

stubs/tensorflow/tensorflow/experimental/dtensor.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ from _typeshed import Incomplete
22

33
from tensorflow._aliases import IntArray, IntDataSequence
44

5+
Layout = Incomplete
6+
57
class Mesh:
68
def __init__(
79
self,

0 commit comments

Comments
 (0)