Skip to content

Commit 8061e58

Browse files
tensorflow: add tf.strings module (#11380)
Partially taken from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/strings.pyi Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
1 parent 0f0e261 commit 8061e58

File tree

2 files changed

+243
-0
lines changed

2 files changed

+243
-0
lines changed

stubs/tensorflow/tensorflow/_aliases.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
3232
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
3333
StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence]
3434
ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any]
35+
# Union of a Tensor, a plain Python int, or a UIntArray.
# Used by tf.strings.unsorted_segment_join for `segment_ids` and `num_segments`.
UIntTensorCompatible: TypeAlias = tf.Tensor | int | UIntArray
36+
# Recursive alias: a Tensor, a str, a NumPy array of np.str_, or an arbitrarily
# nested Sequence of any of those.
StringTensorCompatible: TypeAlias = tf.Tensor | str | npt.NDArray[np.str_] | Sequence[StringTensorCompatible]
3537

3638
TensorCompatible: TypeAlias = ScalarTensorCompatible | Sequence[TensorCompatible]
3739
# _TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=TensorCompatible)
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from collections.abc import Sequence
2+
from typing import Literal, TypeVar, overload
3+
4+
from tensorflow import RaggedTensor, Tensor
5+
from tensorflow._aliases import StringTensorCompatible, TensorCompatible, UIntTensorCompatible
6+
from tensorflow.dtypes import DType
7+
8+
# Constrained TypeVar: binds to exactly Tensor or exactly RaggedTensor (never a union of the two).
_TensorOrRaggedTensor = TypeVar("_TensorOrRaggedTensor", Tensor, RaggedTensor)
9+
10+
# as_string: the result container mirrors the input container.
@overload
def as_string(  # Tensor-compatible (dense) input -> Tensor
    input: TensorCompatible,
    precision: int = -1,
    scientific: bool = False,
    shortest: bool = False,
    width: int = -1,
    fill: str = "",
    name: str | None = None,
) -> Tensor: ...
@overload
def as_string(  # RaggedTensor input -> RaggedTensor
    input: RaggedTensor,
    precision: int = -1,
    scientific: bool = False,
    shortest: bool = False,
    width: int = -1,
    fill: str = "",
    name: str | None = None,
) -> RaggedTensor: ...
30+
# bytes_split: always yields a RaggedTensor, for dense and ragged input alike.
def bytes_split(
    input: TensorCompatible | RaggedTensor,
    name: str | None = None,
) -> RaggedTensor: ...
31+
# format: `placeholder` defaults to "{}" and `summarize` to 3; the result is a Tensor.
def format(
    template: str,
    inputs: TensorCompatible,
    placeholder: str = "{}",
    summarize: int = 3,
    name: str | None = None,
) -> Tensor: ...
34+
# join: takes a sequence of dense-compatible and/or ragged values; the default separator is empty.
def join(
    inputs: Sequence[TensorCompatible | RaggedTensor],
    separator: str = "",
    name: str | None = None,
) -> Tensor: ...
35+
# length: `unit` is either "BYTE" (default) or "UTF8_CHAR"; the result container mirrors the input.
@overload
def length(
    input: TensorCompatible,
    unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE",
    name: str | None = None,
) -> Tensor: ...
@overload
def length(
    input: RaggedTensor,
    unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE",
    name: str | None = None,
) -> RaggedTensor: ...
39+
# lower: `encoding` is "" (default) or "utf-8"; the result container mirrors the input.
@overload
def lower(
    input: TensorCompatible,
    encoding: Literal["utf-8", ""] = "",
    name: str | None = None,
) -> Tensor: ...
@overload
def lower(
    input: RaggedTensor,
    encoding: Literal["utf-8", ""] = "",
    name: str | None = None,
) -> RaggedTensor: ...
43+
# ngrams: always returns a RaggedTensor.
#
# `pad_values` is either a single string (used for both sides), a
# (left_pad_value, right_pad_value) tuple, or None for no padding. The TensorFlow
# documentation states the values must be Python strings, so the tuple form is
# typed as tuple[str, str] rather than a tuple of ints.
def ngrams(
    data: StringTensorCompatible | RaggedTensor,
    ngram_width: int | Sequence[int],
    separator: str = " ",
    pad_values: tuple[str, str] | str | None = None,
    padding_width: int | None = None,
    preserve_short_sequences: bool = False,
    name: str | None = None,
) -> RaggedTensor: ...
52+
# reduce_join: accepts dense-compatible or ragged string input and returns a Tensor.
def reduce_join(
    inputs: StringTensorCompatible | RaggedTensor,
    axis: int | None = None,  # None is the default
    keepdims: bool = False,
    separator: str = "",
    name: str | None = None,
) -> Tensor: ...
59+
# regex_full_match: the result container mirrors the container of `input`.
@overload
def regex_full_match(
    input: StringTensorCompatible,
    pattern: StringTensorCompatible,
    name: str | None = None,
) -> Tensor: ...
@overload
def regex_full_match(
    input: RaggedTensor,
    pattern: StringTensorCompatible,
    name: str | None = None,
) -> RaggedTensor: ...
63+
# regex_replace: `replace_global` defaults to True; the result container mirrors `input`.
@overload
def regex_replace(  # dense-compatible input -> Tensor
    input: StringTensorCompatible,
    pattern: StringTensorCompatible,
    rewrite: StringTensorCompatible,
    replace_global: bool = True,
    name: str | None = None,
) -> Tensor: ...
@overload
def regex_replace(  # RaggedTensor input -> RaggedTensor
    input: RaggedTensor,
    pattern: StringTensorCompatible,
    rewrite: StringTensorCompatible,
    replace_global: bool = True,
    name: str | None = None,
) -> RaggedTensor: ...
79+
# split: always returns a RaggedTensor; `sep` may be omitted (None) and `maxsplit` defaults to -1.
def split(
    input: StringTensorCompatible | RaggedTensor,
    sep: StringTensorCompatible | None = None,
    maxsplit: int = -1,
    name: str | None = None,
) -> RaggedTensor: ...
85+
# strip: the result container mirrors the input container.
@overload
def strip(
    input: StringTensorCompatible,
    name: str | None = None,
) -> Tensor: ...
@overload
def strip(
    input: RaggedTensor,
    name: str | None = None,
) -> RaggedTensor: ...
89+
# substr: `unit` is "BYTE" (default) or "UTF8_CHAR"; the result container mirrors `input`.
@overload
def substr(  # dense-compatible input -> Tensor
    input: StringTensorCompatible,
    pos: TensorCompatible,
    len: TensorCompatible,
    unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE",
    name: str | None = None,
) -> Tensor: ...
@overload
def substr(  # RaggedTensor input -> RaggedTensor
    input: RaggedTensor,
    pos: TensorCompatible,
    len: TensorCompatible,
    unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE",
    name: str | None = None,
) -> RaggedTensor: ...
105+
# to_hash_bucket: `num_buckets` is required; the result container mirrors the input.
@overload
def to_hash_bucket(
    input: StringTensorCompatible,
    num_buckets: int,
    name: str | None = None,
) -> Tensor: ...
@overload
def to_hash_bucket(
    input: RaggedTensor,
    num_buckets: int,
    name: str | None = None,
) -> RaggedTensor: ...
109+
# to_hash_bucket_fast: same signature shape as to_hash_bucket; the result container mirrors the input.
@overload
def to_hash_bucket_fast(
    input: StringTensorCompatible,
    num_buckets: int,
    name: str | None = None,
) -> Tensor: ...
@overload
def to_hash_bucket_fast(
    input: RaggedTensor,
    num_buckets: int,
    name: str | None = None,
) -> RaggedTensor: ...
113+
# to_hash_bucket_strong: additionally requires `key`, a sequence of ints; the result container mirrors the input.
@overload
def to_hash_bucket_strong(
    input: StringTensorCompatible,
    num_buckets: int,
    key: Sequence[int],
    name: str | None = None,
) -> Tensor: ...
@overload
def to_hash_bucket_strong(
    input: RaggedTensor,
    num_buckets: int,
    key: Sequence[int],
    name: str | None = None,
) -> RaggedTensor: ...
119+
# to_number: `out_type` has a default that this stub elides (`...`); the result container mirrors the input.
@overload
def to_number(
    input: StringTensorCompatible,
    out_type: DType = ...,
    name: str | None = None,
) -> Tensor: ...
@overload
def to_number(
    input: RaggedTensor,
    out_type: DType = ...,
    name: str | None = None,
) -> RaggedTensor: ...
123+
# unicode_decode: dense-compatible input may produce either a Tensor or a RaggedTensor;
# RaggedTensor input always produces a RaggedTensor.
@overload
def unicode_decode(  # dense-compatible input
    input: StringTensorCompatible,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    replace_control_characters: bool = False,
    name: str | None = None,
) -> Tensor | RaggedTensor: ...
@overload
def unicode_decode(  # RaggedTensor input
    input: RaggedTensor,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    replace_control_characters: bool = False,
    name: str | None = None,
) -> RaggedTensor: ...
141+
# unicode_decode_with_offsets: returns a 2-tuple whose members are the same kind of tensor.
#
# For dense-compatible input the result is either a pair of Tensors or a pair of
# RaggedTensors. This is written as an explicit union because a TypeVar that appears
# only in the return annotation can never be solved from the arguments.
# RaggedTensor input always produces a pair of RaggedTensors.
@overload
def unicode_decode_with_offsets(  # dense-compatible input
    input: StringTensorCompatible,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    replace_control_characters: bool = False,
    name: str | None = None,
) -> tuple[Tensor, Tensor] | tuple[RaggedTensor, RaggedTensor]: ...
@overload
def unicode_decode_with_offsets(  # RaggedTensor input
    input: RaggedTensor,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    replace_control_characters: bool = False,
    name: str | None = None,
) -> tuple[RaggedTensor, RaggedTensor]: ...
159+
# unicode_encode: `output_encoding` is required and limited to three literal values;
# the result container mirrors the input container.
@overload
def unicode_encode(  # dense-compatible input -> Tensor
    input: TensorCompatible,
    output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    name: str | None = None,
) -> Tensor: ...
@overload
def unicode_encode(  # RaggedTensor input -> RaggedTensor
    input: RaggedTensor,
    output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    name: str | None = None,
) -> RaggedTensor: ...
175+
# unicode_script: the result container mirrors the input container.
@overload
def unicode_script(
    input: TensorCompatible,
    name: str | None = None,
) -> Tensor: ...
@overload
def unicode_script(
    input: RaggedTensor,
    name: str | None = None,
) -> RaggedTensor: ...
179+
# unicode_split: dense-compatible input may produce either a Tensor or a RaggedTensor;
# RaggedTensor input always produces a RaggedTensor.
@overload
def unicode_split(  # dense-compatible input
    input: StringTensorCompatible,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    name: str | None = None,
) -> Tensor | RaggedTensor: ...
@overload
def unicode_split(  # RaggedTensor input
    input: RaggedTensor,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    name: str | None = None,
) -> RaggedTensor: ...
195+
# unicode_split_with_offsets: returns a 2-tuple whose members are the same kind of tensor.
#
# For dense-compatible input the result is either a pair of Tensors or a pair of
# RaggedTensors. This is written as an explicit union because a TypeVar that appears
# only in the return annotation can never be solved from the arguments.
# RaggedTensor input always produces a pair of RaggedTensors.
@overload
def unicode_split_with_offsets(  # dense-compatible input
    input: StringTensorCompatible,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    name: str | None = None,
) -> tuple[Tensor, Tensor] | tuple[RaggedTensor, RaggedTensor]: ...
@overload
def unicode_split_with_offsets(  # RaggedTensor input
    input: RaggedTensor,
    input_encoding: str,
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    name: str | None = None,
) -> tuple[RaggedTensor, RaggedTensor]: ...
211+
# unicode_transcode: `input_encoding` is a free-form str, `output_encoding` is one of three
# literal values; the result container mirrors the input container.
@overload
def unicode_transcode(  # dense-compatible input -> Tensor
    input: StringTensorCompatible,
    input_encoding: str,
    output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    replace_control_characters: bool = False,
    name: str | None = None,
) -> Tensor: ...
@overload
def unicode_transcode(  # RaggedTensor input -> RaggedTensor
    input: RaggedTensor,
    input_encoding: str,
    output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
    errors: Literal["replace", "strict", "ignore"] = "replace",
    replacement_char: int = 65533,
    replace_control_characters: bool = False,
    name: str | None = None,
) -> RaggedTensor: ...
231+
# unsorted_segment_join: only dense-compatible string input is accepted (no RaggedTensor);
# `segment_ids` and `num_segments` share the UIntTensorCompatible type. Returns a Tensor.
def unsorted_segment_join(
    inputs: StringTensorCompatible,
    segment_ids: UIntTensorCompatible,
    num_segments: UIntTensorCompatible,
    separator: str = "",
    name: str | None = None,
) -> Tensor: ...
238+
# upper: `encoding` is "" (default) or "utf-8"; the result container mirrors the input.
@overload
def upper(
    input: TensorCompatible,
    encoding: Literal["utf-8", ""] = "",
    name: str | None = None,
) -> Tensor: ...
@overload
def upper(
    input: RaggedTensor,
    encoding: Literal["utf-8", ""] = "",
    name: str | None = None,
) -> RaggedTensor: ...

0 commit comments

Comments
 (0)