@@ -12,13 +12,6 @@ from numpy._typing import (
1212 _ArrayLikeInt_co ,
1313 _ArrayLikeObject_co ,
1414 _ArrayLikeUInt_co ,
15- _DTypeLikeBool ,
16- _DTypeLikeComplex ,
17- _DTypeLikeComplex_co ,
18- _DTypeLikeFloat ,
19- _DTypeLikeInt ,
20- _DTypeLikeObject ,
21- _DTypeLikeUInt ,
2215)
2316
2417__all__ = ["einsum" , "einsum_path" ]
@@ -30,6 +23,14 @@ _OptimizeKind: TypeAlias = bool | Literal["greedy", "optimal"] | Sequence[str |
3023_CastingSafe : TypeAlias = Literal ["no" , "equiv" , "safe" , "same_kind" , "same_value" ]
3124_CastingUnsafe : TypeAlias = Literal ["unsafe" ]
3225
26+ _ToDTypeUInt : TypeAlias = (
27+ _nt .ToDTypeUInt8 | _nt .ToDTypeUInt16 | _nt .ToDTypeUInt32 | _nt .ToDTypeUInt64 | _nt .ToDTypeULong
28+ )
29+ _ToDTypeInt : TypeAlias = _nt .ToDTypeInt8 | _nt .ToDTypeInt16 | _nt .ToDTypeInt32 | _nt .ToDTypeInt64 | _nt .ToDTypeLong
30+ _ToDTypeFloat : TypeAlias = _nt .ToDTypeFloat16 | _nt .ToDTypeFloat32 | _nt .ToDTypeFloat64 | _nt .ToDTypeLongDouble
31+ _ToDTypeComplex : TypeAlias = _nt .ToDTypeComplex64 | _nt .ToDTypeComplex128 | _nt .ToDTypeCLongDouble
32+ _ToDTypeComplex_co : TypeAlias = _nt .ToDTypeBool | _ToDTypeUInt | _ToDTypeInt | _ToDTypeFloat | _ToDTypeComplex
33+
3334# TODO: Properly handle the `casting`-based combinatorics
3435# TODO: We need to evaluate the content `__subscripts` in order
3536# to identify whether or an array or scalar is returned. At a cursory
@@ -43,7 +44,7 @@ def einsum(
4344 * operands : _ArrayLikeBool_co ,
4445 out : None = None ,
4546 optimize : _OptimizeKind = False ,
46- dtype : _DTypeLikeBool | None = None ,
47+ dtype : _nt . ToDTypeBool | None = None ,
4748 order : _OrderKACF = "K" ,
4849 casting : _CastingSafe = "safe" ,
4950) -> Incomplete : ...
@@ -53,7 +54,7 @@ def einsum(
5354 / ,
5455 * operands : _ArrayLikeUInt_co ,
5556 out : None = None ,
56- dtype : _DTypeLikeUInt | None = None ,
57+ dtype : _ToDTypeUInt | None = None ,
5758 order : _OrderKACF = "K" ,
5859 casting : _CastingSafe = "safe" ,
5960 optimize : _OptimizeKind = False ,
@@ -64,7 +65,7 @@ def einsum(
6465 / ,
6566 * operands : _ArrayLikeInt_co ,
6667 out : None = None ,
67- dtype : _DTypeLikeInt | None = None ,
68+ dtype : _ToDTypeInt | None = None ,
6869 order : _OrderKACF = "K" ,
6970 casting : _CastingSafe = "safe" ,
7071 optimize : _OptimizeKind = False ,
@@ -75,7 +76,7 @@ def einsum(
7576 / ,
7677 * operands : _ArrayLikeFloat_co ,
7778 out : None = None ,
78- dtype : _DTypeLikeFloat | None = None ,
79+ dtype : _ToDTypeFloat | None = None ,
7980 order : _OrderKACF = "K" ,
8081 casting : _CastingSafe = "safe" ,
8182 optimize : _OptimizeKind = False ,
@@ -86,7 +87,7 @@ def einsum(
8687 / ,
8788 * operands : _ArrayLikeComplex_co ,
8889 out : None = None ,
89- dtype : _DTypeLikeComplex | None = None ,
90+ dtype : _ToDTypeComplex | None = None ,
9091 order : _OrderKACF = "K" ,
9192 casting : _CastingSafe = "safe" ,
9293 optimize : _OptimizeKind = False ,
@@ -97,7 +98,7 @@ def einsum(
9798 / ,
9899 * operands : Any ,
99100 casting : _CastingUnsafe ,
100- dtype : _DTypeLikeComplex_co | None = None ,
101+ dtype : _ToDTypeComplex_co | None = None ,
101102 out : None = None ,
102103 order : _OrderKACF = "K" ,
103104 optimize : _OptimizeKind = False ,
@@ -108,7 +109,7 @@ def einsum(
108109 / ,
109110 * operands : _ArrayLikeComplex_co ,
110111 out : _ArrayT ,
111- dtype : _DTypeLikeComplex_co | None = None ,
112+ dtype : _ToDTypeComplex_co | None = None ,
112113 order : _OrderKACF = "K" ,
113114 casting : _CastingSafe = "safe" ,
114115 optimize : _OptimizeKind = False ,
@@ -120,7 +121,7 @@ def einsum(
120121 * operands : Any ,
121122 out : _ArrayT ,
122123 casting : _CastingUnsafe ,
123- dtype : _DTypeLikeComplex_co | None = None ,
124+ dtype : _ToDTypeComplex_co | None = None ,
124125 order : _OrderKACF = "K" ,
125126 optimize : _OptimizeKind = False ,
126127) -> _ArrayT : ...
@@ -130,7 +131,7 @@ def einsum(
130131 / ,
131132 * operands : _ArrayLikeObject_co ,
132133 out : None = None ,
133- dtype : _DTypeLikeObject | None = None ,
134+ dtype : _nt . ToDTypeObject | None = None ,
134135 order : _OrderKACF = "K" ,
135136 casting : _CastingSafe = "safe" ,
136137 optimize : _OptimizeKind = False ,
@@ -141,7 +142,7 @@ def einsum(
141142 / ,
142143 * operands : Any ,
143144 casting : _CastingUnsafe ,
144- dtype : _DTypeLikeObject | None = None ,
145+ dtype : _nt . ToDTypeObject | None = None ,
145146 out : None = None ,
146147 order : _OrderKACF = "K" ,
147148 optimize : _OptimizeKind = False ,
@@ -152,7 +153,7 @@ def einsum(
152153 / ,
153154 * operands : _ArrayLikeObject_co ,
154155 out : _ArrayT ,
155- dtype : _DTypeLikeObject | None = None ,
156+ dtype : _nt . ToDTypeObject | None = None ,
156157 order : _OrderKACF = "K" ,
157158 casting : _CastingSafe = "safe" ,
158159 optimize : _OptimizeKind = False ,
@@ -164,7 +165,7 @@ def einsum(
164165 * operands : Any ,
165166 out : _ArrayT ,
166167 casting : _CastingUnsafe ,
167- dtype : _DTypeLikeObject | None = None ,
168+ dtype : _nt . ToDTypeObject | None = None ,
168169 order : _OrderKACF = "K" ,
169170 optimize : _OptimizeKind = False ,
170171) -> _ArrayT : ...
@@ -175,7 +176,7 @@ def einsum(
175176def einsum_path (
176177 subscripts : str | _ArrayLikeInt_co ,
177178 / ,
178- * operands : _ArrayLikeComplex_co | _DTypeLikeObject ,
179+ * operands : _ArrayLikeComplex_co | _nt . ToDTypeObject ,
179180 optimize : _OptimizeKind = "greedy" ,
180181 einsum_call : L [False ] = False ,
181182) -> tuple [list [str | tuple [int , ...]], str ]: ...
0 commit comments