Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions pre_commit_hooks/end_of_file_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import IO


def fix_file(file_obj: IO[bytes]) -> int:
def fix_file(file_obj: IO[bytes], check_only=False) -> int:
# Test for newline at end of file
# Empty files will throw IOError here
try:
Expand All @@ -18,7 +18,8 @@ def fix_file(file_obj: IO[bytes]) -> int:
if last_character not in {b'\n', b'\r'} and last_character != b'':
# Needs this seek for windows, otherwise IOError
file_obj.seek(0, os.SEEK_END)
file_obj.write(b'\n')
if not check_only:
file_obj.write(b'\n')
return 1

while last_character in {b'\n', b'\r'}:
Expand All @@ -27,7 +28,8 @@ def fix_file(file_obj: IO[bytes]) -> int:
# If we've reached the beginning of the file and it is all
# linebreaks then we can make this file empty
file_obj.seek(0)
file_obj.truncate()
if not check_only:
file_obj.truncate()
return 1

# Go back two bytes and read a character
Expand All @@ -43,14 +45,16 @@ def fix_file(file_obj: IO[bytes]) -> int:
return 0
elif remaining.startswith(sequence):
file_obj.seek(position + len(sequence))
file_obj.truncate()
if not check_only:
file_obj.truncate()
return 1

return 0


def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('--check', action='store_true', help='Check without fixing')
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)

Expand All @@ -59,11 +63,13 @@ def main(argv: Sequence[str] | None = None) -> int:
for filename in args.filenames:
# Read as binary so we can read byte-by-byte
with open(filename, 'rb+') as file_obj:
ret_for_file = fix_file(file_obj)
ret_for_file = fix_file(file_obj, args.check)
if ret_for_file:
print(f'Fixing {filename}')
if args.check:
print(f'Wrong ending of file: {filename}')
else:
print(f'Fixing {filename}')
retv |= ret_for_file

return retv


Expand Down
30 changes: 23 additions & 7 deletions pre_commit_hooks/trailing_whitespace_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ def _fix_file(
filename: str,
is_markdown: bool,
chars: bytes | None,
check_only: bool = False,
error_lines: list[int] = None,
) -> bool:
with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown, chars) for line in lines]
newlines = [_process_line(line, is_markdown, chars, line_num, error_lines) for line_num, line in enumerate(lines)]
if newlines != lines:
with open(filename, mode='wb') as file_processed:
for line in newlines:
file_processed.write(line)
if not check_only:
with open(filename, mode='wb') as file_processed:
for line in newlines:
file_processed.write(line)
return True
else:
return False
Expand All @@ -26,7 +29,10 @@ def _process_line(
line: bytes,
is_markdown: bool,
chars: bytes | None,
line_num: int,
error_lines: list[int] | None,
) -> bytes:
org_line = line
if line[-2:] == b'\r\n':
eol = b'\r\n'
line = line[:-2]
Expand All @@ -38,11 +44,15 @@ def _process_line(
# preserve trailing two-space for non-blank lines in markdown files
if is_markdown and (not line.isspace()) and line.endswith(b' '):
return line[:-2].rstrip(chars) + b' ' + eol
return line.rstrip(chars) + eol
result = line.rstrip(chars) + eol
if error_lines is not None and org_line != result:
error_lines.append(line_num + 1)
return result


def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('--check', action='store_true', help='Check without fixing')
parser.add_argument(
'--no-markdown-linebreak-ext',
action='store_true',
Expand Down Expand Up @@ -93,8 +103,14 @@ def main(argv: Sequence[str] | None = None) -> int:
for filename in args.filenames:
_, extension = os.path.splitext(filename.lower())
md = all_markdown or extension in md_exts
if _fix_file(filename, md, chars):
print(f'Fixing {filename}')
error_lines = []
if _fix_file(filename, md, chars, args.check, error_lines):
if args.check:
location = ','.join(map(str, error_lines[:4]))
location += '...' if len(error_lines) > 4 else ''
print(f'Trailing whitespace check failed: {filename} @ {location}')
else:
print(f'Fixing {filename}')
return_code = 1
return return_code

Expand Down
56 changes: 39 additions & 17 deletions tests/end_of_file_fixer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,56 @@

# Input, expected return value, expected output
TESTS = (
(b'foo\n', 0, b'foo\n'),
(b'', 0, b''),
(b'\n\n', 1, b''),
(b'\n\n\n\n', 1, b''),
(b'foo', 1, b'foo\n'),
(b'foo\n\n\n', 1, b'foo\n'),
(b'\xe2\x98\x83', 1, b'\xe2\x98\x83\n'),
(b'foo\r\n', 0, b'foo\r\n'),
(b'foo\r\n\r\n\r\n', 1, b'foo\r\n'),
(b'foo\r', 0, b'foo\r'),
(b'foo\r\r\r\r', 1, b'foo\r'),
(b'foo\n', 0, b'foo\n', None),
(b'', 0, b'', None),
(b'\n\n', 1, b'', None),
(b'\n\n\n\n', 1, b'', None),
(b'foo', 1, b'foo\n', None),
(b'foo\n\n\n', 1, b'foo\n', None),
(b'\xe2\x98\x83', 1, b'\xe2\x98\x83\n', None),
(b'foo\r\n', 0, b'foo\r\n', None),
(b'foo\r\n\r\n\r\n', 1, b'foo\r\n', None),
(b'foo\r', 0, b'foo\r', None),
(b'foo\r\r\r\r', 1, b'foo\r', None),

(b'foo\n', 0, b'foo\n', '--check'),
(b'', 0, b'', '--check'),
(b'\n\n', 1, b'\n\n', '--check'),
(b'\n\n\n\n', 1, b'\n\n\n\n', '--check'),
(b'foo', 1, b'foo', '--check'),
(b'foo\n\n\n', 1, b'foo\n\n\n', '--check'),
(b'\xe2\x98\x83', 1, b'\xe2\x98\x83', '--check'),
(b'foo\r\n', 0, b'foo\r\n', '--check'),
(b'foo\r\n\r\n\r\n', 1, b'foo\r\n\r\n\r\n', '--check'),
(b'foo\r', 0, b'foo\r', '--check'),
(b'foo\r\r\r\r', 1, b'foo\r\r\r\r', '--check'),
)


@pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS)
def test_fix_file(input_s, expected_retval, output):
@pytest.mark.parametrize(('input_s', 'expected_retval', 'output', 'options'), TESTS)
def test_fix_file(input_s, expected_retval, output, options):
if options is None:
options = []
elif isinstance(options, str):
options = [options]

file_obj = io.BytesIO(input_s)
ret = fix_file(file_obj)
ret = fix_file(file_obj, '--check' in [*options])
assert file_obj.getvalue() == output
assert ret == expected_retval


@pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS)
def test_integration(input_s, expected_retval, output, tmpdir):
@pytest.mark.parametrize(('input_s', 'expected_retval', 'output', 'options'), TESTS)
def test_integration(input_s, expected_retval, output, options, tmpdir):
path = tmpdir.join('file.txt')
path.write_binary(input_s)

ret = main([str(path)])
if options is None:
options = []
elif isinstance(options, str):
options = [options]

ret = main([*options, str(path)])
file_output = path.read_binary()

assert file_output == output
Expand Down
18 changes: 18 additions & 0 deletions tests/trailing_whitespace_fixer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,24 @@ def test_fixes_trailing_whitespace(input_s, expected, tmpdir):
assert path.read() == expected


@pytest.mark.parametrize(
('input_s', 'exit_code', 'lines'),
(
('foo \nbar \n', 1, [1, 2]),
('bar\t\nbaz\t\n', 1, [1, 2]),
('bar\nbaz\t\n', 1, [2]),
),
)
def test_fixes_trailing_whitespace_check_only(capsys, input_s, exit_code, lines, tmpdir):
path = tmpdir.join('file.md')
path.write(input_s)
assert main(('--check', str(path))) == exit_code
assert path.read() == input_s
captured = capsys.readouterr()
location = '@ ' + ','.join(map(str, lines))
assert location in captured.out


def test_ok_no_newline_end_of_file(tmpdir):
filename = tmpdir.join('f')
filename.write_binary(b'foo\nbar')
Expand Down
Loading