diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py index a88425c6..bb00fec1 100644 --- a/pre_commit_hooks/end_of_file_fixer.py +++ b/pre_commit_hooks/end_of_file_fixer.py @@ -6,7 +6,7 @@ from typing import IO -def fix_file(file_obj: IO[bytes]) -> int: +def fix_file(file_obj: IO[bytes], check_only=False) -> int: # Test for newline at end of file # Empty files will throw IOError here try: @@ -18,7 +18,8 @@ def fix_file(file_obj: IO[bytes]) -> int: if last_character not in {b'\n', b'\r'} and last_character != b'': # Needs this seek for windows, otherwise IOError file_obj.seek(0, os.SEEK_END) - file_obj.write(b'\n') + if not check_only: + file_obj.write(b'\n') return 1 while last_character in {b'\n', b'\r'}: @@ -27,7 +28,8 @@ def fix_file(file_obj: IO[bytes]) -> int: # If we've reached the beginning of the file and it is all # linebreaks then we can make this file empty file_obj.seek(0) - file_obj.truncate() + if not check_only: + file_obj.truncate() return 1 # Go back two bytes and read a character @@ -43,7 +45,8 @@ def fix_file(file_obj: IO[bytes]) -> int: return 0 elif remaining.startswith(sequence): file_obj.seek(position + len(sequence)) - file_obj.truncate() + if not check_only: + file_obj.truncate() return 1 return 0 @@ -51,6 +54,7 @@ def fix_file(file_obj: IO[bytes]) -> int: def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() + parser.add_argument('--check', action='store_true', help='Check without fixing') parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -59,11 +63,13 @@ def main(argv: Sequence[str] | None = None) -> int: for filename in args.filenames: # Read as binary so we can read byte-by-byte with open(filename, 'rb+') as file_obj: - ret_for_file = fix_file(file_obj) + ret_for_file = fix_file(file_obj, args.check) if ret_for_file: - print(f'Fixing {filename}') + if args.check: + print(f'Wrong ending of file: {filename}') + else: + print(f'Fixing {filename}') retv |= ret_for_file - return retv diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py index dab8b14a..fc6dc861 100644 --- a/pre_commit_hooks/trailing_whitespace_fixer.py +++ b/pre_commit_hooks/trailing_whitespace_fixer.py @@ -9,14 +9,17 @@ def _fix_file( filename: str, is_markdown: bool, chars: bytes | None, + check_only: bool = False, + error_lines: list[int] = None, ) -> bool: with open(filename, mode='rb') as file_processed: lines = file_processed.readlines() - newlines = [_process_line(line, is_markdown, chars) for line in lines] + newlines = [_process_line(line, is_markdown, chars, line_num, error_lines) for line_num, line in enumerate(lines)] if newlines != lines: - with open(filename, mode='wb') as file_processed: - for line in newlines: - file_processed.write(line) + if not check_only: + with open(filename, mode='wb') as file_processed: + for line in newlines: + file_processed.write(line) return True else: return False @@ -26,7 +29,10 @@ def _process_line( line: bytes, is_markdown: bool, chars: bytes | None, + line_num: int, + error_lines: list[int] | None, ) -> bytes: + org_line = line if line[-2:] == b'\r\n': eol = b'\r\n' line = line[:-2] @@ -38,11 +44,15 @@ def _process_line( # preserve trailing two-space for non-blank lines in markdown files if is_markdown and (not line.isspace()) and line.endswith(b' '): return line[:-2].rstrip(chars) + b' ' + eol - return line.rstrip(chars) + eol + result = line.rstrip(chars) + eol + if error_lines is not None and org_line != result: + error_lines.append(line_num + 1) + return result def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() + parser.add_argument('--check', action='store_true', help='Check without fixing') parser.add_argument( '--no-markdown-linebreak-ext', action='store_true', @@ -93,8 +103,14 @@ def main(argv: Sequence[str] | None = None) -> int: for filename in args.filenames: _, extension = os.path.splitext(filename.lower()) md = all_markdown or extension in md_exts - if _fix_file(filename, md, chars): - print(f'Fixing {filename}') + error_lines = [] + if _fix_file(filename, md, chars, args.check, error_lines): + if args.check: + location = ','.join(map(str, error_lines[:4])) + location += '...' if len(error_lines) > 4 else '' + print(f'Trailing whitespace check failed: {filename} @ {location}') + else: + print(f'Fixing {filename}') return_code = 1 return return_code diff --git a/tests/end_of_file_fixer_test.py b/tests/end_of_file_fixer_test.py index 8a5d889e..db6ef351 100644 --- a/tests/end_of_file_fixer_test.py +++ b/tests/end_of_file_fixer_test.py @@ -10,34 +10,56 @@ # Input, expected return value, expected output TESTS = ( - (b'foo\n', 0, b'foo\n'), - (b'', 0, b''), - (b'\n\n', 1, b''), - (b'\n\n\n\n', 1, b''), - (b'foo', 1, b'foo\n'), - (b'foo\n\n\n', 1, b'foo\n'), - (b'\xe2\x98\x83', 1, b'\xe2\x98\x83\n'), - (b'foo\r\n', 0, b'foo\r\n'), - (b'foo\r\n\r\n\r\n', 1, b'foo\r\n'), - (b'foo\r', 0, b'foo\r'), - (b'foo\r\r\r\r', 1, b'foo\r'), + (b'foo\n', 0, b'foo\n', None), + (b'', 0, b'', None), + (b'\n\n', 1, b'', None), + (b'\n\n\n\n', 1, b'', None), + (b'foo', 1, b'foo\n', None), + (b'foo\n\n\n', 1, b'foo\n', None), + (b'\xe2\x98\x83', 1, b'\xe2\x98\x83\n', None), + (b'foo\r\n', 0, b'foo\r\n', None), + (b'foo\r\n\r\n\r\n', 1, b'foo\r\n', None), + (b'foo\r', 0, b'foo\r', None), + (b'foo\r\r\r\r', 1, b'foo\r', None), + + (b'foo\n', 0, b'foo\n', '--check'), + (b'', 0, b'', '--check'), + (b'\n\n', 1, b'\n\n', '--check'), + (b'\n\n\n\n', 1, b'\n\n\n\n', '--check'), + (b'foo', 1, b'foo', '--check'), + (b'foo\n\n\n', 1, b'foo\n\n\n', '--check'), + (b'\xe2\x98\x83', 1, b'\xe2\x98\x83', '--check'), + (b'foo\r\n', 0, b'foo\r\n', '--check'), + (b'foo\r\n\r\n\r\n', 1, b'foo\r\n\r\n\r\n', '--check'), + (b'foo\r', 0, b'foo\r', '--check'), + (b'foo\r\r\r\r', 1, b'foo\r\r\r\r', '--check'), ) -@pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS) -def test_fix_file(input_s, expected_retval, output): +@pytest.mark.parametrize(('input_s', 'expected_retval', 'output', 'options'), TESTS) +def test_fix_file(input_s, expected_retval, output, options): + if options is None: + options = [] + elif isinstance(options, str): + options = [options] + file_obj = io.BytesIO(input_s) - ret = fix_file(file_obj) + ret = fix_file(file_obj, '--check' in [*options]) assert file_obj.getvalue() == output assert ret == expected_retval -@pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS) -def test_integration(input_s, expected_retval, output, tmpdir): +@pytest.mark.parametrize(('input_s', 'expected_retval', 'output', 'options'), TESTS) +def test_integration(input_s, expected_retval, output, options, tmpdir): path = tmpdir.join('file.txt') path.write_binary(input_s) - ret = main([str(path)]) + if options is None: + options = [] + elif isinstance(options, str): + options = [options] + + ret = main([*options, str(path)]) file_output = path.read_binary() assert file_output == output diff --git a/tests/trailing_whitespace_fixer_test.py b/tests/trailing_whitespace_fixer_test.py index c07497a2..949d4206 100644 --- a/tests/trailing_whitespace_fixer_test.py +++ b/tests/trailing_whitespace_fixer_test.py @@ -19,6 +19,24 @@ def test_fixes_trailing_whitespace(input_s, expected, tmpdir): assert path.read() == expected +@pytest.mark.parametrize( + ('input_s', 'exit_code', 'lines'), + ( + ('foo \nbar \n', 1, [1, 2]), + ('bar\t\nbaz\t\n', 1, [1, 2]), + ('bar\nbaz\t\n', 1, [2]), + ), +) +def test_fixes_trailing_whitespace_check_only(capsys, input_s, exit_code, lines, tmpdir): + path = tmpdir.join('file.md') + path.write(input_s) + assert main(('--check', str(path))) == exit_code + assert path.read() == input_s + captured = capsys.readouterr() + location = '@ ' + ','.join(map(str, lines)) + assert location in captured.out + + def test_ok_no_newline_end_of_file(tmpdir): filename = tmpdir.join('f') filename.write_binary(b'foo\nbar')