Skip to content

Commit 4353597

Browse files
committed
improve types according to test/test_filter.py
1 parent 3e651c8 commit 4353597

File tree

3 files changed

+38
-18
lines changed

3 files changed

+38
-18
lines changed

pygit2/_pygit2.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ from .enums import (
5252
ResetMode,
5353
SortMode,
5454
)
55+
from .filter import Filter
5556
from .remotes import Remote
5657
from .repository import BaseRepository
5758
from .submodules import SubmoduleCollection
@@ -522,6 +523,7 @@ class DiffStats:
522523

523524
class FilterSource:
524525
# probably incomplete
526+
repo: object
525527
pass
526528

527529
class GitError(Exception): ...
@@ -1036,5 +1038,7 @@ def option(opt: Option, *args) -> None: ...
10361038
def reference_is_valid_name(refname: str) -> bool: ...
10371039
def tree_entry_cmp(a: Object, b: Object) -> int: ...
10381040
def _cache_enums() -> None: ...
1041+
def filter_register(name: str, filter: type[Filter]) -> None: ...
1042+
def filter_unregister(name: str) -> None: ...
10391043

10401044
_OidArg = str | Oid

pygit2/filter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Filter:
5858
def nattrs(cls) -> int:
5959
return len(cls.attributes.split())
6060

61-
def check(self, src: FilterSource, attr_values: List[Optional[str]]):
61+
def check(self, src: FilterSource, attr_values: List[Optional[str]]) -> None:
6262
"""
6363
Check whether this filter should be applied to the given source.
6464
@@ -77,7 +77,7 @@ def check(self, src: FilterSource, attr_values: List[Optional[str]]):
7777

7878
def write(
7979
self, data: bytes, src: FilterSource, write_next: Callable[[bytes], None]
80-
):
80+
) -> None:
8181
"""
8282
Write input `data` to this filter.
8383
@@ -95,7 +95,7 @@ def write(
9595
"""
9696
write_next(data)
9797

98-
def close(self, write_next: Callable[[bytes], None]):
98+
def close(self, write_next: Callable[[bytes], None]) -> None:
9999
"""
100100
Close this filter.
101101

test/test_filter.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,52 @@
11
import codecs
22
from io import BytesIO
3+
from typing import Callable, Generator
34

45
import pytest
56

67
import pygit2
8+
from pygit2 import Blob, Filter, FilterSource, Repository
79
from pygit2.enums import BlobFilter
810
from pygit2.errors import Passthrough
911

1012

11-
def _rot13(data):
13+
def _rot13(data: bytes) -> bytes:
1214
return codecs.encode(data.decode('utf-8'), 'rot_13').encode('utf-8')
1315

1416

1517
class _Rot13Filter(pygit2.Filter):
1618
attributes = 'text'
1719

18-
def write(self, data, src, write_next):
20+
def write(
21+
self,
22+
data: bytes,
23+
src: FilterSource,
24+
write_next: Callable[[bytes], None],
25+
) -> None:
1926
return super().write(_rot13(data), src, write_next)
2027

2128

2229
class _BufferedFilter(pygit2.Filter):
2330
attributes = 'text'
2431

25-
def __init__(self):
32+
def __init__(self) -> None:
2633
super().__init__()
2734
self.buf = BytesIO()
2835

29-
def write(self, data, src, write_next):
36+
def write(
37+
self,
38+
data: bytes,
39+
src: FilterSource,
40+
write_next: Callable[[bytes], None],
41+
) -> None:
3042
self.buf.write(data)
3143

32-
def close(self, write_next):
44+
def close(self, write_next: Callable[[bytes], None]) -> None:
3345
write_next(_rot13(self.buf.getvalue()))
3446

3547

3648
class _PassthroughFilter(_Rot13Filter):
37-
def check(self, src, attr_values):
49+
def check(self, src: FilterSource, attr_values: list[str | None]) -> None:
3850
assert attr_values == [None]
3951
assert src.repo
4052
raise Passthrough
@@ -45,36 +57,37 @@ class _UnmatchedFilter(_Rot13Filter):
4557

4658

4759
@pytest.fixture
48-
def rot13_filter():
60+
def rot13_filter() -> Generator[None, None, None]:
4961
pygit2.filter_register('rot13', _Rot13Filter)
5062
yield
5163
pygit2.filter_unregister('rot13')
5264

5365

5466
@pytest.fixture
55-
def passthrough_filter():
67+
def passthrough_filter() -> Generator[None, None, None]:
5668
pygit2.filter_register('passthrough-rot13', _PassthroughFilter)
5769
yield
5870
pygit2.filter_unregister('passthrough-rot13')
5971

6072

6173
@pytest.fixture
62-
def buffered_filter():
74+
def buffered_filter() -> Generator[None, None, None]:
6375
pygit2.filter_register('buffered-rot13', _BufferedFilter)
6476
yield
6577
pygit2.filter_unregister('buffered-rot13')
6678

6779

6880
@pytest.fixture
69-
def unmatched_filter():
81+
def unmatched_filter() -> Generator[None, None, None]:
7082
pygit2.filter_register('unmatched-rot13', _UnmatchedFilter)
7183
yield
7284
pygit2.filter_unregister('unmatched-rot13')
7385

7486

75-
def test_filter(testrepo, rot13_filter):
87+
def test_filter(testrepo: Repository, rot13_filter: Filter) -> None:
7688
blob_oid = testrepo.create_blob_fromworkdir('bye.txt')
7789
blob = testrepo[blob_oid]
90+
assert isinstance(blob, Blob)
7891
flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD
7992
assert b'olr jbeyq\n' == blob.data
8093
with pygit2.BlobIO(blob) as reader:
@@ -83,9 +96,10 @@ def test_filter(testrepo, rot13_filter):
8396
assert b'bye world\n' == reader.read()
8497

8598

86-
def test_filter_buffered(testrepo, buffered_filter):
99+
def test_filter_buffered(testrepo: Repository, buffered_filter: Filter) -> None:
87100
blob_oid = testrepo.create_blob_fromworkdir('bye.txt')
88101
blob = testrepo[blob_oid]
102+
assert isinstance(blob, Blob)
89103
flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD
90104
assert b'olr jbeyq\n' == blob.data
91105
with pygit2.BlobIO(blob) as reader:
@@ -94,9 +108,10 @@ def test_filter_buffered(testrepo, buffered_filter):
94108
assert b'bye world\n' == reader.read()
95109

96110

97-
def test_filter_passthrough(testrepo, passthrough_filter):
111+
def test_filter_passthrough(testrepo: Repository, passthrough_filter: Filter) -> None:
98112
blob_oid = testrepo.create_blob_fromworkdir('bye.txt')
99113
blob = testrepo[blob_oid]
114+
assert isinstance(blob, Blob)
100115
flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD
101116
assert b'bye world\n' == blob.data
102117
with pygit2.BlobIO(blob) as reader:
@@ -105,9 +120,10 @@ def test_filter_passthrough(testrepo, passthrough_filter):
105120
assert b'bye world\n' == reader.read()
106121

107122

108-
def test_filter_unmatched(testrepo, unmatched_filter):
123+
def test_filter_unmatched(testrepo: Repository, unmatched_filter: Filter) -> None:
109124
blob_oid = testrepo.create_blob_fromworkdir('bye.txt')
110125
blob = testrepo[blob_oid]
126+
assert isinstance(blob, Blob)
111127
flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD
112128
assert b'bye world\n' == blob.data
113129
with pygit2.BlobIO(blob) as reader:
@@ -116,7 +132,7 @@ def test_filter_unmatched(testrepo, unmatched_filter):
116132
assert b'bye world\n' == reader.read()
117133

118134

119-
def test_filter_cleanup(dirtyrepo, rot13_filter):
135+
def test_filter_cleanup(dirtyrepo: Repository, rot13_filter: Filter) -> None:
120136
# Indirectly test that pygit2_filter_cleanup has the GIL
121137
# before calling pygit2_filter_payload_free.
122138
dirtyrepo.diff()

0 commit comments

Comments
 (0)