Skip to content

Commit d93ee88

Browse files
authored
tensorflow add tf.random module (#11359)
Partially from Mehdi Drissi's stubs.
1 parent c9f74e6 commit d93ee88

File tree

2 files changed

+232
-1
lines changed

2 files changed

+232
-1
lines changed

stubs/tensorflow/tensorflow/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,6 @@ class RaggedTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
406406
@classmethod
407407
def from_value(cls, value: RaggedTensor) -> Self: ...
408408

409-
def __getattr__(name: str) -> Incomplete: ...
410409
def convert_to_tensor(
411410
value: TensorCompatible | IndexedSlices,
412411
dtype: DTypeLike | None = None,
@@ -440,3 +439,4 @@ def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTe
440439
@overload
441440
def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ...
442441
def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ...
442+
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from collections.abc import Sequence
2+
from enum import Enum
3+
from typing import Literal
4+
from typing_extensions import TypeAlias
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
import tensorflow as tf
9+
from tensorflow._aliases import DTypeLike, ScalarTensorCompatible, ShapeLike
10+
from tensorflow.python.trackable import autotrackable
11+
12+
class Algorithm(Enum):
13+
PHILOX = 1
14+
THREEFRY = 2
15+
AUTO_SELECT = 3
16+
17+
_Alg: TypeAlias = Literal[Algorithm.PHILOX, Algorithm.THREEFRY, Algorithm.AUTO_SELECT, "philox", "threefry", "auto_select"]
18+
19+
class Generator(autotrackable.AutoTrackable):
20+
@classmethod
21+
def from_state(cls, state: tf.Variable, alg: _Alg | None) -> Generator: ...
22+
@classmethod
23+
def from_seed(cls, seed: int, alg: _Alg | None = None) -> Generator: ...
24+
@classmethod
25+
def from_non_deterministic_state(cls, alg: _Alg | None = None) -> Generator: ...
26+
@classmethod
27+
def from_key_counter(
28+
cls, key: ScalarTensorCompatible, counter: Sequence[ScalarTensorCompatible], alg: _Alg | None
29+
) -> Generator: ...
30+
def __init__(self, copy_from: Generator | None = None, state: tf.Variable | None = None, alg: _Alg | None = None) -> None: ...
31+
def reset(self, state: tf.Variable) -> None: ...
32+
def reset_from_seed(self, seed: int) -> None: ...
33+
def reset_from_key_counter(self, key: ScalarTensorCompatible, counter: tf.Variable) -> None: ...
34+
@property
35+
def state(self) -> tf.Variable: ...
36+
@property
37+
def algorithm(self) -> int: ...
38+
@property
39+
def key(self) -> ScalarTensorCompatible: ...
40+
def skip(self, delta: int) -> tf.Tensor: ...
41+
def normal(
42+
self,
43+
shape: tf.Tensor | Sequence[int],
44+
mean: ScalarTensorCompatible = 0.0,
45+
stddev: ScalarTensorCompatible = 1.0,
46+
dtype: DTypeLike = ...,
47+
name: str | None = None,
48+
) -> tf.Tensor: ...
49+
def truncated_normal(
50+
self,
51+
shape: ShapeLike,
52+
mean: ScalarTensorCompatible = 0.0,
53+
stddev: ScalarTensorCompatible = 1.0,
54+
dtype: DTypeLike = ...,
55+
name: str | None = None,
56+
) -> tf.Tensor: ...
57+
def uniform(
58+
self,
59+
shape: ShapeLike,
60+
minval: ScalarTensorCompatible = 0,
61+
maxval: ScalarTensorCompatible | None = None,
62+
dtype: DTypeLike = ...,
63+
name: str | None = None,
64+
) -> tf.Tensor: ...
65+
def uniform_full_int(self, shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None) -> tf.Tensor: ...
66+
def binomial(
67+
self, shape: ShapeLike, counts: tf.Tensor, probs: tf.Tensor, dtype: DTypeLike = ..., name: str | None = None
68+
) -> tf.Tensor: ...
69+
def make_seeds(self, count: int = 1) -> tf.Tensor: ...
70+
def split(self, count: int = 1) -> list[Generator]: ...
71+
72+
def all_candidate_sampler(
73+
true_classes: tf.Tensor, num_true: int, num_sampled: int, unique: bool, seed: int | None = None, name: str | None = None
74+
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: ...
75+
def categorical(
76+
logits: tf.Tensor,
77+
num_samples: int | tf.Tensor,
78+
dtype: DTypeLike | None = None,
79+
seed: int | None = None,
80+
name: str | None = None,
81+
) -> tf.Tensor: ...
82+
def create_rng_state(seed: int, alg: _Alg) -> npt.NDArray[np.int64]: ...
83+
def fixed_unigram_candidate_sampler(
84+
true_classes: tf.Tensor,
85+
num_true: int,
86+
num_sampled: int,
87+
unique: bool,
88+
range_max: int,
89+
vocab_file: str = "",
90+
distortion: float = 1.0,
91+
num_reserved_ids: int = 0,
92+
num_shards: int = 1,
93+
shard: int = 0,
94+
unigrams: Sequence[float] = (),
95+
seed: int | None = None,
96+
name: str | None = None,
97+
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: ...
98+
def fold_in(seed: tf.Tensor | Sequence[int], data: int, alg: _Alg = "auto_select") -> int: ...
99+
def gamma(
100+
shape: tf.Tensor | Sequence[int],
101+
alpha: tf.Tensor | float | Sequence[float],
102+
beta: tf.Tensor | float | Sequence[float] | None = None,
103+
dtype: DTypeLike = ...,
104+
seed: int | None = None,
105+
name: str | None = None,
106+
) -> tf.Tensor: ...
107+
def get_global_generator() -> Generator: ...
108+
def learned_unigram_candidate_sampler(
109+
true_classes: tf.Tensor,
110+
num_true: int,
111+
num_sampled: int,
112+
unique: bool,
113+
range_max: int,
114+
seed: int | None = None,
115+
name: str | None = None,
116+
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: ...
117+
def log_uniform_candidate_sampler(
118+
true_classes: tf.Tensor,
119+
num_true: int,
120+
num_sampled: int,
121+
unique: bool,
122+
range_max: int,
123+
seed: int | None = None,
124+
name: str | None = None,
125+
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: ...
126+
def normal(
127+
shape: ShapeLike,
128+
mean: ScalarTensorCompatible = 0.0,
129+
stddev: ScalarTensorCompatible = 1.0,
130+
dtype: DTypeLike = ...,
131+
seed: int | None = None,
132+
name: str | None = None,
133+
) -> tf.Tensor: ...
134+
def poisson(
135+
shape: ShapeLike, lam: ScalarTensorCompatible, dtype: DTypeLike = ..., seed: int | None = None, name: str | None = None
136+
) -> tf.Tensor: ...
137+
def set_global_generator(generator: Generator) -> None: ...
138+
def set_seed(seed: int) -> None: ...
139+
def shuffle(value: tf.Tensor, seed: int | None = None, name: str | None = None) -> tf.Tensor: ...
140+
def split(seed: tf.Tensor | Sequence[int], num: int = 2, alg: _Alg = "auto_select") -> tf.Tensor: ...
141+
def stateless_binomial(
142+
shape: ShapeLike,
143+
seed: tuple[int, int] | tf.Tensor,
144+
counts: tf.Tensor,
145+
probs: tf.Tensor,
146+
output_dtype: DTypeLike = ...,
147+
name: str | None = None,
148+
) -> tf.Tensor: ...
149+
def stateless_categorical(
150+
logits: tf.Tensor,
151+
num_samples: int | tf.Tensor,
152+
seed: tuple[int, int] | tf.Tensor,
153+
dtype: DTypeLike = ...,
154+
name: str | None = None,
155+
) -> tf.Tensor: ...
156+
def stateless_gamma(
157+
shape: ShapeLike,
158+
seed: tuple[int, int] | tf.Tensor,
159+
alpha: tf.Tensor,
160+
beta: tf.Tensor | None = None,
161+
dtype: DTypeLike = ...,
162+
name: str | None = None,
163+
) -> tf.Tensor: ...
164+
def stateless_normal(
165+
shape: tf.Tensor | Sequence[int],
166+
seed: tuple[int, int] | tf.Tensor,
167+
mean: float | tf.Tensor = 0.0,
168+
stddev: float | tf.Tensor = 1.0,
169+
dtype: DTypeLike = ...,
170+
name: str | None = None,
171+
alg: _Alg = "auto_select",
172+
) -> tf.Tensor: ...
173+
def stateless_parameterized_truncated_normal(
174+
shape: tf.Tensor | Sequence[int],
175+
seed: tuple[int, int] | tf.Tensor,
176+
means: float | tf.Tensor = 0.0,
177+
stddevs: float | tf.Tensor = 1.0,
178+
minvals: tf.Tensor | float = -2.0,
179+
maxvals: tf.Tensor | float = 2.0,
180+
name: str | None = None,
181+
) -> tf.Tensor: ...
182+
def stateless_poisson(
183+
shape: tf.Tensor | Sequence[int],
184+
seed: tuple[int, int] | tf.Tensor,
185+
lam: tf.Tensor,
186+
dtype: DTypeLike = ...,
187+
name: str | None = None,
188+
) -> tf.Tensor: ...
189+
def stateless_truncated_normal(
190+
shape: tf.Tensor | Sequence[int],
191+
seed: tuple[int, int] | tf.Tensor,
192+
mean: float | tf.Tensor = 0.0,
193+
stddev: float | tf.Tensor = 1.0,
194+
dtype: DTypeLike = ...,
195+
name: str | None = None,
196+
alg: _Alg = "auto_select",
197+
) -> tf.Tensor: ...
198+
def stateless_uniform(
199+
shape: tf.Tensor | Sequence[int],
200+
seed: tuple[int, int] | tf.Tensor,
201+
minval: float | tf.Tensor = 0,
202+
maxval: float | tf.Tensor | None = None,
203+
dtype: DTypeLike = ...,
204+
name: str | None = None,
205+
alg: _Alg = "auto_select",
206+
) -> tf.Tensor: ...
207+
def truncated_normal(
208+
shape: tf.Tensor | Sequence[int],
209+
mean: float | tf.Tensor = 0.0,
210+
stddev: float | tf.Tensor = 1.0,
211+
dtype: DTypeLike = ...,
212+
seed: int | None = None,
213+
name: str | None = None,
214+
) -> tf.Tensor: ...
215+
def uniform(
216+
shape: tf.Tensor | Sequence[int],
217+
minval: float | tf.Tensor = 0,
218+
maxval: float | tf.Tensor | None = None,
219+
dtype: DTypeLike = ...,
220+
seed: int | None = None,
221+
name: str | None = None,
222+
) -> tf.Tensor: ...
223+
def uniform_candidate_sampler(
224+
true_classes: tf.Tensor,
225+
num_true: int,
226+
num_sampled: int,
227+
unique: bool,
228+
range_max: int,
229+
seed: int | None = None,
230+
name: str | None = None,
231+
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: ...

0 commit comments

Comments
 (0)