Skip to content

Commit b14096a

Browse files
update stub files
- Added default values in stub files. - Added ModeInt type union to allow integer values where MODE enum was expected. - Used upper bounds for DataT and CDataT types instead of variable constraints.
1 parent 5333fd7 commit b14096a

3 files changed

Lines changed: 23 additions & 23 deletions

File tree

pywt/_extensions/_dwt.pyi

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from numpy.typing import NDArray
22

3-
from pywt import MODE, CDataT, Wavelet
3+
from ._pywt import CDataT, ModeInt, Wavelet
44

55
def dwt_max_level(data_len: int, filter_len: int) -> int: ...
6-
def dwt_coeff_len(size_t: int, filter_len: int, mode: MODE) -> int: ...
7-
def dwt_single(data: NDArray[CDataT], wavelet: Wavelet, mode: MODE) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ...
8-
def dwt_axis(data: NDArray[CDataT], wavelet: Wavelet, mode: MODE, axis: int = ...) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ...
9-
def idwt_single(cA: NDArray[CDataT], cD: NDArray[CDataT], wavelet: Wavelet, mode: MODE) -> NDArray[CDataT]: ...
10-
def idwt_axis(coefs_a: NDArray[CDataT], coefs_d: NDArray[CDataT], wavelet: Wavelet, mode: MODE, axis: int = ...) -> NDArray[CDataT]: ...
6+
def dwt_coeff_len(size_t: int, filter_len: int, mode: ModeInt) -> int: ...
7+
def dwt_single(data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ...
8+
def dwt_axis(data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, axis: int = 0) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ...
9+
def idwt_single(cA: NDArray[CDataT], cD: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt) -> NDArray[CDataT]: ...
10+
def idwt_axis(coefs_a: NDArray[CDataT], coefs_d: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, axis: int = 0) -> NDArray[CDataT]: ...
1111
def upcoef(do_rec_a: bool, coeffs: NDArray[CDataT], wavelet: Wavelet, level: int, take: int) -> NDArray[CDataT]: ...
12-
def downcoef(do_dec_a: bool, data: NDArray[CDataT], wavelet: Wavelet, mode: MODE, level: int) -> NDArray[CDataT]: ...
12+
def downcoef(do_dec_a: bool, data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, level: int) -> NDArray[CDataT]: ...

pywt/_extensions/_pywt.pyi

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import IntEnum
2-
from typing import Any, Literal, Optional, TypeVar
2+
from typing import Any, Literal, Optional, TypeAlias, TypeVar
33

44
import numpy as np
55

@@ -20,17 +20,13 @@ _WaveletFamily = Literal[
2020
"cmor",
2121
]
2222

23-
DataT = TypeVar("DataT", np.float32, np.float64)
23+
DataT = TypeVar("DataT", bound=np.float32 | np.float64)
2424

2525
CDataT = TypeVar(
26-
"CDataT",
27-
np.float32,
28-
np.float64,
29-
np.complex64,
30-
np.complex128,
26+
"CDataT", bound=np.float32 | np.float64 | np.complex64 | np.complex128
3127
)
3228

33-
_Kind = Literal["all", "continuous", "discrete"]
29+
_Kind: TypeAlias = Literal["all", "continuous", "discrete"]
3430

3531
_Symmetry = Literal[
3632
"asymmetric",
@@ -53,6 +49,8 @@ class MODE(IntEnum):
5349
MODE_ANTIREFLECT = 8
5450
MODE_MAX = 9
5551

52+
ModeInt = MODE | Literal[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
53+
5654
ModeName = Literal[
5755
"zero",
5856
"constant",
@@ -84,11 +82,11 @@ class _Modes:
8482

8583
Modes = _Modes()
8684

87-
def wavelist(family: _WaveletFamily | None = ..., kind: _Kind = ...) -> list[str]: ...
88-
def families(short: bool = ...) -> list[str]: ...
85+
def wavelist(family: _WaveletFamily | None = None, kind: _Kind = "all") -> list[str]: ...
86+
def families(short: bool = True) -> list[str]: ...
8987

9088
class Wavelet:
91-
def __init__(self, name: str = ..., filter_bank: Any = ...) -> None: ...
89+
def __init__(self, name: str = "", filter_bank: Any = None) -> None: ...
9290
def __len__(self) -> int: ...
9391
@property
9492
def name(self) -> str: ...
@@ -140,7 +138,7 @@ class Wavelet:
140138
) -> tuple[list[float], list[float], list[float], list[float]]: ...
141139

142140
class ContinuousWavelet:
143-
def __init__(self, name: str = ..., dtype: DataT = ...) -> None: ...
141+
def __init__(self, name: str = "", dtype: DataT = np.float64) -> None: ...
144142
@property
145143
def family_number(self) -> int: ...
146144
@property
@@ -182,4 +180,6 @@ class ContinuousWavelet:
182180
@property
183181
def symmetry(self) -> _Symmetry: ...
184182

185-
def DiscreteContinuousWavelet(name: str = ..., filter_bank: Any = ...) -> Wavelet | ContinuousWavelet : ...
183+
def DiscreteContinuousWavelet(
184+
name: str = "", filter_bank: Any = None
185+
) -> Wavelet | ContinuousWavelet: ...

pywt/_extensions/_swt.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from numpy.typing import NDArray
22

3-
from pywt import CDataT, Wavelet
3+
from ._pywt import CDataT, Wavelet
44

55
def swt_max_level(input_len: int) -> int: ...
6-
def swt(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, trim_approx: bool = ...) -> NDArray[CDataT]: ...
7-
def swt_axis(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, axis: int = ..., trim_approx: bool = ...) -> NDArray[CDataT]: ...
6+
def swt(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, trim_approx: bool = False) -> NDArray[CDataT]: ...
7+
def swt_axis(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, axis: int = 0, trim_approx: bool = False) -> NDArray[CDataT]: ...

0 commit comments

Comments
 (0)