Skip to content

Commit 59e906e

Browse files
committed
typing: Make things stricter (4/4)
Add annotations to all calls. Signed-off-by: Stephen Finucane <stephen@that.guru>
1 parent b17d689 commit 59e906e

7 files changed

Lines changed: 119 additions & 90 deletions

File tree

git_pw/api.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
FC = TypeVar('FC', bound=Callable[..., Any])
2424

25-
Filters = list[tuple[str, str]]
25+
Filters = list[tuple[str, str | int | None]]
2626

2727

2828
class HTTPTokenAuth(requests.auth.AuthBase):
@@ -236,7 +236,7 @@ def version() -> tuple[int, int]:
236236

237237
def get(url: str, params: Filters | None = None) -> dict[str, Any]:
238238
"""Get a JSON document from the API and return it as a dict."""
239-
return _get(url, params, stream=False).json()
239+
return cast(dict[str, Any], _get(url, params, stream=False).json())
240240

241241

242242
def download(
@@ -323,7 +323,7 @@ def index(
323323
params = params or []
324324
params.append(('project', _get_project()))
325325

326-
return _get(url, params).json()
326+
return cast(list[dict[str, Any]], _get(url, params).json())
327327

328328

329329
def detail(
@@ -346,7 +346,7 @@ def detail(
346346
# NOTE(stephenfin): All resources must have a trailing '/'
347347
url = '/'.join([_get_server(), resource_type, str(resource_id), ''])
348348

349-
return _get(url, params, stream=False).json()
349+
return cast(dict[str, Any], _get(url, params, stream=False).json())
350350

351351

352352
def create(
@@ -367,7 +367,7 @@ def create(
367367
# NOTE(stephenfin): All resources must have a trailing '/'
368368
url = '/'.join([_get_server(), resource_type, ''])
369369

370-
return _post(url, data).json()
370+
return cast(dict[str, Any], _post(url, data).json())
371371

372372

373373
def delete(resource_type: str, resource_id: str | int) -> None:
@@ -408,38 +408,38 @@ def update(
408408
# NOTE(stephenfin): All resources must have a trailing '/'
409409
url = '/'.join([_get_server(), resource_type, str(resource_id), ''])
410410

411-
return _patch(url, data).json()
411+
return cast(dict[str, Any], _patch(url, data).json())
412412

413413

414414
def validate_minimum_version(
415415
min_version: tuple[int, int],
416416
msg: str,
417-
) -> Callable[[Any], Any]:
418-
def inner(f):
417+
) -> Callable[[FC], FC]:
418+
def inner(f: FC) -> FC:
419419
@click.pass_context
420-
def new_func(ctx, *args, **kwargs):
420+
def new_func(ctx: click.Context, *args: Any, **kwargs: Any) -> Any:
421421
if version() < min_version:
422422
LOG.error(msg)
423423
sys.exit(1)
424424

425425
return ctx.invoke(f, *args, **kwargs)
426426

427-
return update_wrapper(new_func, f)
427+
return cast(FC, update_wrapper(new_func, f))
428428

429429
return inner
430430

431431

432432
def validate_multiple_filter_support(f: FC) -> FC:
433433
@click.pass_context
434-
def new_func(ctx, *args, **kwargs):
434+
def new_func(ctx: click.Context, *args: Any, **kwargs: Any) -> Any:
435435
if version() >= (1, 1):
436436
return ctx.invoke(f, *args, **kwargs)
437437

438438
for param in ctx.command.params:
439439
if not param.multiple:
440440
continue
441441

442-
if param.name in ('headers'):
442+
if param.name is None or param.name == 'headers':
443443
continue
444444

445445
value = list(kwargs[param.name] or [])

git_pw/bundle.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def download_cmd(bundle_id: str, output: str | None) -> None:
8686
LOG.info('Downloaded bundle to %s', path)
8787

8888

89-
def _show_bundle(bundle: dict[str, Any], fmt: str) -> None:
90-
def _format_patch(patch):
91-
return '%-4d %s' % (patch.get('id'), patch.get('name'))
89+
def _show_bundle(bundle: dict[str, Any], fmt: str | None) -> None:
90+
def _format_patch(patch: dict[str, Any]) -> str:
91+
return '%-4d %s' % (patch['id'], patch.get('name'))
9292

9393
output = [
9494
('ID', bundle.get('id')),
@@ -137,7 +137,15 @@ def show_cmd(fmt: str, bundle_id: str) -> None:
137137
@utils.format_options(headers=_list_headers)
138138
@click.argument('name', required=False)
139139
@api.validate_multiple_filter_support
140-
def list_cmd(owners, limit, page, sort, fmt, headers, name):
140+
def list_cmd(
141+
owners: tuple[str, ...],
142+
limit: int | None,
143+
page: int | None,
144+
sort: str,
145+
fmt: str | None,
146+
headers: tuple[str, ...],
147+
name: str | None,
148+
) -> None:
141149
"""List bundles.
142150
143151
List bundles on the Patchwork instance.
@@ -150,7 +158,7 @@ def list_cmd(owners, limit, page, sort, fmt, headers, name):
150158
sort,
151159
)
152160

153-
params = []
161+
params: list[tuple[str, str | int | None]] = []
154162

155163
for owner in owners:
156164
# we support server-side filtering by username (but not email) in 1.1

git_pw/patch.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Patch subcommands.
33
"""
44

5+
import datetime
56
import logging
67
import os
78
import sys
@@ -70,7 +71,12 @@ def _get_apply_patch_deps() -> bool:
7071
),
7172
)
7273
@click.argument('args', nargs=-1, type=click.UNPROCESSED)
73-
def apply_cmd(patch_id, series, deps, args):
74+
def apply_cmd(
75+
patch_id: int,
76+
series: int | None,
77+
deps: bool,
78+
args: tuple[str, ...],
79+
) -> None:
7480
"""Apply patch.
7581
7682
Apply a patch locally using the 'git-am' command. Any additional ARGS
@@ -86,14 +92,17 @@ def apply_cmd(patch_id, series, deps, args):
8692

8793
patch = api.detail('patches', patch_id)
8894

95+
series_filter: int | str | None = series
8996
if deps and not series:
90-
series = '*'
97+
series_filter = '*'
9198
elif not deps:
92-
series = None
99+
series_filter = None
93100

94101
mbox = api.download(
95102
patch['mbox'],
96-
[('series', str(series))] if series is not None else None,
103+
[('series', str(series_filter))]
104+
if series_filter is not None
105+
else None,
97106
)
98107

99108
if mbox:
@@ -117,7 +126,7 @@ def apply_cmd(patch_id, series, deps, args):
117126
default=True,
118127
help='Show patch in mbox format.',
119128
)
120-
def download_cmd(patch_id, output, fmt):
129+
def download_cmd(patch_id: int, output: str | None, fmt: str) -> None:
121130
"""Download patch in diff or mbox format.
122131
123132
Download a patch but do not apply it. ``OUTPUT`` is optional and can be an
@@ -131,6 +140,7 @@ def download_cmd(patch_id, output, fmt):
131140

132141
if fmt == 'diff':
133142
if output and not os.path.isdir(output):
143+
output_path: int | str
134144
if output == '-':
135145
output_path = 0 # stdout fd
136146
else:
@@ -154,9 +164,9 @@ def download_cmd(patch_id, output, fmt):
154164
LOG.info('Downloaded patch to %s', path)
155165

156166

157-
def _show_patch(patch, fmt):
158-
def _format_series(series):
159-
return '%-4d %s' % (series.get('id'), series.get('name') or '-')
167+
def _show_patch(patch: dict[str, Any], fmt: str | None) -> None:
168+
def _format_series(series: dict[str, Any]) -> str:
169+
return '%-4d %s' % (series['id'], series.get('name') or '-')
160170

161171
output = [
162172
('ID', patch.get('id')),
@@ -167,17 +177,17 @@ def _format_series(series):
167177
(
168178
'Submitter',
169179
'{} ({})'.format(
170-
patch.get('submitter').get('name'),
171-
patch.get('submitter').get('email'),
180+
(patch.get('submitter') or {}).get('name'),
181+
(patch.get('submitter') or {}).get('email'),
172182
),
173183
),
174184
('State', patch.get('state')),
175185
('Archived', patch.get('archived')),
176-
('Project', patch.get('project').get('name')),
186+
('Project', (patch.get('project') or {}).get('name')),
177187
(
178188
'Delegate',
179189
(
180-
patch.get('delegate').get('username')
190+
(patch.get('delegate') or {}).get('username')
181191
if patch.get('delegate')
182192
else ''
183193
),
@@ -186,7 +196,7 @@ def _format_series(series):
186196
]
187197

188198
prefix = 'Series'
189-
for series in patch.get('series'):
199+
for series in patch.get('series') or []:
190200
output.append((prefix, _format_series(series)))
191201
prefix = ''
192202

@@ -196,7 +206,7 @@ def _format_series(series):
196206
@click.command(name='show')
197207
@utils.format_options
198208
@click.argument('patch_id', type=click.INT)
199-
def show_cmd(fmt, patch_id):
209+
def show_cmd(fmt: str | None, patch_id: int) -> None:
200210
"""Show information about patch.
201211
202212
Retrieve Patchwork metadata for a patch.
@@ -208,7 +218,7 @@ def show_cmd(fmt, patch_id):
208218
_show_patch(patch, fmt)
209219

210220

211-
def _get_states():
221+
def _get_states() -> list[str] | tuple[str, ...]:
212222
return CONF.states.split(',') if CONF.states else _default_states
213223

214224

@@ -244,7 +254,14 @@ def _get_states():
244254
help='Set the patch archived state.',
245255
)
246256
@utils.format_options
247-
def update_cmd(patch_ids, commit_ref, state, delegate, archived, fmt):
257+
def update_cmd(
258+
patch_ids: tuple[int, ...],
259+
commit_ref: str | None,
260+
state: str | None,
261+
delegate: str | None,
262+
archived: bool | None,
263+
fmt: str | None,
264+
) -> None:
248265
"""Update one or more patches.
249266
250267
Updates one or more Patches on the Patchwork instance. Some operations may
@@ -339,20 +356,20 @@ def update_cmd(patch_ids, commit_ref, state, delegate, archived, fmt):
339356
@click.argument('name', required=False)
340357
@api.validate_multiple_filter_support
341358
def list_cmd(
342-
states,
343-
submitters,
344-
delegates,
345-
hashes,
346-
archived,
347-
since,
348-
before,
349-
limit,
350-
page,
351-
sort,
352-
fmt,
353-
headers,
354-
name,
355-
):
359+
states: tuple[str, ...],
360+
submitters: tuple[str, ...],
361+
delegates: tuple[str, ...],
362+
hashes: tuple[str, ...],
363+
archived: bool,
364+
since: datetime.datetime | None,
365+
before: datetime.datetime | None,
366+
limit: int | None,
367+
page: int | None,
368+
sort: str,
369+
fmt: str | None,
370+
headers: tuple[str, ...],
371+
name: str | None,
372+
) -> None:
356373
"""List patches.
357374
358375
List patches on the Patchwork instance.
@@ -367,7 +384,7 @@ def list_cmd(
367384
archived,
368385
)
369386

370-
params = []
387+
params: list[tuple[str, str | int | None]] = []
371388

372389
for state in states:
373390
params.append(('state', state))

git_pw/series.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Series subcommands.
33
"""
44

5+
import datetime
56
import logging
67
import os.path
78
import sys
@@ -34,7 +35,7 @@
3435
'first.',
3536
)
3637
@click.argument('args', nargs=-1, type=click.UNPROCESSED)
37-
def apply_cmd(series_id, args, deps):
38+
def apply_cmd(series_id: int, args: tuple[str, ...], deps: bool) -> None:
3839
"""Apply series.
3940
4041
Apply a series locally using the 'git-am' command. Any additional ARGS
@@ -91,7 +92,7 @@ def apply_cmd(series_id, args, deps):
9192
default=True,
9293
help='Download all series patches to one file',
9394
)
94-
def download_cmd(series_id, output, fmt):
95+
def download_cmd(series_id: int, output: str | None, fmt: str) -> None:
9596
"""Download series in mbox format.
9697
9798
Download a series but do not apply it. ``OUTPUT`` is optional and can be an
@@ -131,7 +132,7 @@ def download_cmd(series_id, output, fmt):
131132
@click.command(name='show')
132133
@utils.format_options
133134
@click.argument('series_id', type=click.INT)
134-
def show_cmd(fmt, series_id):
135+
def show_cmd(fmt: str | None, series_id: int) -> None:
135136
"""Show information about series.
136137
137138
Retrieve Patchwork metadata for a series.
@@ -140,8 +141,8 @@ def show_cmd(fmt, series_id):
140141

141142
series = api.detail('series', series_id)
142143

143-
def _format_submission(submission):
144-
return '%-4d %s' % (submission.get('id'), submission.get('name'))
144+
def _format_submission(submission: dict[str, Any]) -> str:
145+
return '%-4d %s' % (submission['id'], submission.get('name'))
145146

146147
output = [
147148
('ID', series.get('id')),
@@ -165,7 +166,7 @@ def _format_submission(submission):
165166
(
166167
'Cover',
167168
(
168-
_format_submission(series.get('cover_letter'))
169+
_format_submission(series['cover_letter'])
169170
if series.get('cover_letter')
170171
else ''
171172
),
@@ -195,7 +196,17 @@ def _format_submission(submission):
195196
@utils.format_options(headers=_list_headers)
196197
@click.argument('name', required=False)
197198
@api.validate_multiple_filter_support
198-
def list_cmd(submitters, limit, page, sort, fmt, headers, name, since, before):
199+
def list_cmd(
200+
submitters: tuple[str, ...],
201+
limit: int | None,
202+
page: int | None,
203+
sort: str,
204+
fmt: str | None,
205+
headers: tuple[str, ...],
206+
name: str | None,
207+
since: datetime.datetime | None,
208+
before: datetime.datetime | None,
209+
) -> None:
199210
"""List series.
200211
201212
List series on the Patchwork instance.
@@ -208,7 +219,7 @@ def list_cmd(submitters, limit, page, sort, fmt, headers, name, since, before):
208219
sort,
209220
)
210221

211-
params = []
222+
params: list[tuple[str, str | int | None]] = []
212223

213224
for submitter in submitters:
214225
if submitter.isdigit():

0 commit comments

Comments
 (0)