Skip to content

Commit b5f4588

Browse files
committed
overload SoundFile.read(). add dtype_str TypeAlias to methods.
1 parent 3fafd01 commit b5f4588

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

soundfile.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -886,9 +886,17 @@ def tell(self) -> int:
886886
return self.seek(0, SEEK_CUR)
887887

888888

889-
def read(self, frames: int = -1, dtype: str = 'float64',
890-
always_2d: bool = False, fill_value: float | None = None,
891-
out: AudioData | None = None) -> AudioData:
889+
@overload
890+
def read(self, frames: int = -1, dtype: dtype_str = 'float64',
891+
*, always_2d: Literal[True], fill_value: float | None = None,
892+
out: AudioData_2d | None = None) -> AudioData_2d:...
893+
@overload
894+
def read(self, frames: int = -1, dtype: dtype_str = 'float64',
895+
always_2d: bool = False, fill_value: float | None = None,
896+
out: AudioData | None = None) -> AudioData:...
897+
def read(self, frames: int = -1, dtype: dtype_str = 'float64',
898+
always_2d: bool = False, fill_value: float | None = None,
899+
out: AudioData | None = None) -> AudioData | AudioData_2d:
892900
"""Read from the file and return data as NumPy array.
893901
894902
Reads the given number of frames in the given data format
@@ -982,7 +990,7 @@ def read(self, frames: int = -1, dtype: str = 'float64',
982990
return out
983991

984992

985-
def buffer_read(self, frames: int = -1, dtype: str | None = None) -> memoryview:
993+
def buffer_read(self, frames: int = -1, dtype: dtype_str | None = None) -> memoryview:
986994
"""Read from the file and return data as buffer object.
987995
988996
Reads the given number of *frames* in the given data format
@@ -1017,7 +1025,7 @@ def buffer_read(self, frames: int = -1, dtype: str | None = None) -> memoryview:
10171025
assert read_frames == frames
10181026
return _ffi.buffer(cdata)
10191027

1020-
def buffer_read_into(self, buffer: bytearray | memoryview | Any, dtype: str) -> int:
1028+
def buffer_read_into(self, buffer: bytearray | memoryview | Any, dtype: dtype_str) -> int:
10211029
"""Read from the file into a given buffer object.
10221030
10231031
Fills the given *buffer* with frames in the given data format
@@ -1104,7 +1112,7 @@ def write(self, data: AudioData) -> None:
11041112
assert written == len(data)
11051113
self._update_frames(written)
11061114

1107-
def buffer_write(self, data: Any, dtype: str) -> None:
1115+
def buffer_write(self, data: bytes, dtype: dtype_str) -> None:
11081116
"""Write audio data from a buffer/bytes object to the file.
11091117
11101118
Writes the contents of *data* to the file at the current
@@ -1132,7 +1140,7 @@ def buffer_write(self, data: Any, dtype: str) -> None:
11321140
self._update_frames(written)
11331141

11341142
def blocks(self, blocksize: int | None = None, overlap: int = 0,
1135-
frames: int = -1, dtype: str = 'float64',
1143+
frames: int = -1, dtype: dtype_str = 'float64',
11361144
always_2d: bool = False, fill_value: float | None = None,
11371145
out: AudioData | None = None) -> Generator[AudioData, None, None]:
11381146
"""Return a generator for block-wise reading.
@@ -1477,7 +1485,7 @@ def _prepare_read(self, start, stop, frames):
14771485
self.seek(start, SEEK_SET)
14781486
return frames
14791487

1480-
def copy_metadata(self):
1488+
def copy_metadata(self) -> dict[str, str]:
14811489
"""Get all metadata present in this SoundFile
14821490
14831491
Returns

0 commit comments

Comments
 (0)