# requirements_txt_fixer.py
# Forked from pre-commit/pre-commit-hooks.
from __future__ import annotations
import argparse
import re
from collections.abc import Sequence
from typing import IO
# Status codes returned by fix_requirements() and accumulated by main(),
# whose result is used as the process exit status.
PASS = 0
FAIL = 1
class Requirement:
    """One requirement entry together with the comment lines preceding it.

    ``value`` holds the raw bytes of the requirement (line endings included;
    several physical lines when backslash continuations are used) and
    ``comments`` holds the comment / blank lines collected before it.  A
    ``value`` consisting of a single newline byte marks a top-of-file
    comment block and always sorts first.
    """

    # First comparison operator inside an already isolated name token.
    UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?')
    # Leading run of characters up to the first ``;`` or whitespace.
    UNTIL_SEP = re.compile(rb'[^;\s]+')
    # Something, a comparison operator, then something else.
    VERSION_SPECIFIED = re.compile(b'.+(={2,3}|!=|~=|>=?|<=?).+')

    def __init__(self) -> None:
        self.value: bytes | None = None
        self.comments: list[bytes] = []

    @property
    def name(self) -> bytes:
        """Return the lower-cased name used as the sort key.

        For values containing ``#egg=`` or ``&egg=`` the name is everything
        after that marker; otherwise it is the leading token of the value
        with any version comparison stripped off.
        """
        assert self.value is not None, self.value
        name = self.value.lower()
        for egg in (b'#egg=', b'&egg='):
            if egg in self.value:
                return name.partition(egg)[-1]

        # Indented requirement lines are stored verbatim in ``value``; skip
        # the indentation so the token regex (which cannot start on
        # whitespace) still finds the name instead of failing the assertion.
        m = self.UNTIL_SEP.match(name.lstrip())
        assert m is not None, self.value

        name = m.group()
        m = self.UNTIL_COMPARISON.search(name)
        if not m:
            return name

        return name[:m.start()]

    def __lt__(self, requirement: Requirement) -> bool:
        """Order requirements by name, top-of-file comment marker first."""
        # \n means top of file comment, so always return True,
        # otherwise just do a string comparison with value.
        assert self.value is not None, self.value
        if self.value == b'\n':
            return True
        if requirement.value == b'\n':
            return False

        own_name = self.name
        other_name = requirement.name
        # if 2 requirements have the same name, the one with comments
        # needs to go first (so that when removing duplicates, the one
        # with comments is kept)
        if own_name == other_name:
            return bool(self.comments) > bool(requirement.comments)
        return own_name < other_name

    def is_complete(self) -> bool:
        """Return True once a value is set and does not end in a backslash.

        A trailing backslash (ignoring the line ending) means the next
        physical line continues this requirement.
        """
        return (
            self.value is not None and
            not self.value.rstrip(b'\r\n').endswith(b'\\')
        )

    def contains_version_specifier(self) -> bool:
        """Return True if the value contains a version comparison."""
        return (
            self.value is not None and
            bool(self.VERSION_SPECIFIED.match(self.value))
        )

    def append_value(self, value: bytes) -> None:
        """Set the value, or extend it with a continuation line."""
        if self.value is not None:
            self.value += value
        else:
            self.value = value
def fix_requirements(f: IO[bytes], fail_without_version: bool) -> int:
    """Sort and de-duplicate the requirements in *f*, rewriting it in place.

    Comment and blank lines stay attached to the requirement that follows
    them, a leading comment block followed by a blank line stays at the top
    of the file, and comments at the very end stay at the end.  The
    ``pkg-resources==0.0.0`` / ``pkg_resources==0.0.0`` entries are dropped.

    Args:
        f: binary file object, readable and writable, positioned at start.
        fail_without_version: when true, report every requirement lacking a
            version comparison and fail without modifying the file.

    Returns:
        ``PASS`` if the file is empty/whitespace-only or already in its
        final form; ``FAIL`` if it was rewritten or (with
        *fail_without_version*) a requirement has no version.
    """
    requirements: list[Requirement] = []
    before = list(f)
    after: list[bytes] = []

    before_string = b''.join(before)

    # adds new line in case one is missing
    # AND a change to the requirements file is needed regardless:
    if before and not before[-1].endswith(b'\n'):
        before[-1] += b'\n'

    # If the file is empty (i.e. only whitespace/newlines) exit early
    if before_string.strip() == b'':
        return PASS

    for line in before:
        # If the most recent requirement object has a value, then it's
        # time to start building the next requirement object.
        if not requirements or requirements[-1].is_complete():
            requirements.append(Requirement())

        requirement = requirements[-1]

        # If we see a newline before any requirements, then this is a
        # top of file comment.
        if len(requirements) == 1 and line.strip() == b'':
            if (
                    requirement.comments and
                    requirement.comments[0].startswith(b'#')
            ):
                requirement.value = b'\n'
            else:
                requirement.comments.append(line)
        elif line.lstrip().startswith(b'#') or line.strip() == b'':
            requirement.comments.append(line)
        else:
            requirement.append_value(line)

    # if a file ends in a comment, preserve it at the end
    if requirements[-1].value is None:
        rest = requirements.pop().comments
    else:
        rest = []

    # find and remove pkg-resources==0.0.0
    # which is automatically added by broken pip package under Debian
    requirements = [
        req for req in requirements
        if req.value not in [
            b'pkg-resources==0.0.0\n',
            b'pkg_resources==0.0.0\n',
        ]
    ]

    # check for requirements without a version specified
    if fail_without_version:
        missing_requirement_found = False
        for req in requirements:
            # The top-of-file comment marker is not a requirement: it has
            # no version and no parseable name, so it must be skipped here
            # rather than reported (asking for its name would fail).
            if req.value == b'\n':
                continue
            if not req.contains_version_specifier():
                print(f'Missing version for requirement {req.name.decode()}')
                missing_requirement_found = True

        if missing_requirement_found:
            return FAIL

    # sort the requirements and remove duplicates
    prev = None
    for requirement in sorted(requirements):
        # comments are always kept, even when the value is a duplicate
        after.extend(requirement.comments)
        assert requirement.value, requirement.value
        if prev is None or requirement.value != prev.value:
            after.append(requirement.value)
            prev = requirement
    after.extend(rest)

    after_string = b''.join(after)

    if before_string == after_string:
        return PASS

    print('Sorting requirements')
    f.seek(0)
    f.write(after_string)
    # the new content may be shorter than the old one
    f.truncate()
    return FAIL
def main(argv: Sequence[str] | None = None) -> int:
    """Fix every requirements file named on the command line.

    Each file is opened for binary read/write and handed to
    ``fix_requirements``; the per-file results are OR-ed together and
    returned, so the result is non-zero if any file failed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('filenames', nargs='*', help='Filenames to fix')
    parser.add_argument(
        '--fail-without-version',
        action='store_true',
        help='Fail if a requirement is missing a version',
    )
    options = parser.parse_args(argv)

    exit_code = PASS
    for filename in options.filenames:
        with open(filename, 'rb+') as handle:
            status = fix_requirements(handle, options.fail_without_version)

            if status:
                print(f'Error in file {filename}')

            exit_code |= status

    return exit_code
if __name__ == '__main__':
    # Executed as a script: use main()'s return value as the exit status.
    raise SystemExit(main())