diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py index 1c07379..ad3ca3c 100644 --- a/pre_commit_hooks/end_of_file_fixer.py +++ b/pre_commit_hooks/end_of_file_fixer.py @@ -1,51 +1,43 @@ import argparse -import os -from typing import IO from typing import Optional from typing import Sequence -def fix_file(file_obj: IO[bytes]) -> int: +def _process_file(file_obj: bytes) -> bytes: # Test for newline at end of file # Empty files will throw IOError here - try: - file_obj.seek(-1, os.SEEK_END) - except OSError: - return 0 - last_character = file_obj.read(1) - # last_character will be '' for an empty file - if last_character not in {b'\n', b'\r'} and last_character != b'': - # Needs this seek for windows, otherwise IOError - file_obj.seek(0, os.SEEK_END) - file_obj.write(b'\n') - return 1 - while last_character in {b'\n', b'\r'}: - # Deal with the beginning of the file - if file_obj.tell() == 1: - # If we've reached the beginning of the file and it is all - # linebreaks then we can make this file empty - file_obj.seek(0) - file_obj.truncate() - return 1 + while len(file_obj): + if file_obj[-2:] == b'\r\n': + if len(file_obj) == 2: + return b'' + elif file_obj[-3:-2] not in {b'\n', b'\r'}: + return file_obj + else: + file_obj = file_obj[:-2] + elif file_obj[-1:] in {b'\n', b'\r'}: + if len(file_obj) == 1: + return b'' + elif file_obj[-2:-1] not in {b'\n', b'\r'}: + return file_obj + else: + file_obj = file_obj[:-1] + else: + return file_obj + b'\n' - # Go back two bytes and read a character - file_obj.seek(-2, os.SEEK_CUR) - last_character = file_obj.read(1) + return file_obj - # Our current position is at the end of the file just before any amount of - # newlines. If we find extraneous newlines, then backtrack and trim them. - position = file_obj.tell() - remaining = file_obj.read() - for sequence in (b'\n', b'\r\n', b'\r'): - if remaining == sequence: - return 0 - elif remaining.startswith(sequence): - file_obj.seek(position + len(sequence)) - file_obj.truncate() - return 1 - return 0 +def _fix_file(filename: str) -> bool: + with open(filename, mode='rb') as file_processed: + file_content = file_processed.read() + newcontent = _process_file(file_content) + if newcontent != file_content: + with open(filename, mode='wb') as file_processed: + file_processed.write(newcontent) + return True + else: + return False def main(argv: Optional[Sequence[str]] = None) -> int: @@ -57,11 +49,9 @@ def main(argv: Optional[Sequence[str]] = None) -> int: for filename in args.filenames: # Read as binary so we can read byte-by-byte - with open(filename, 'rb+') as file_obj: - ret_for_file = fix_file(file_obj) - if ret_for_file: - print(f'Fixing {filename}') - retv |= ret_for_file + if _fix_file(filename): + print(f'Fixing {filename}') + retv = 1 return retv diff --git a/tests/end_of_file_fixer_test.py b/tests/end_of_file_fixer_test.py index 60b9e82..2125c14 100644 --- a/tests/end_of_file_fixer_test.py +++ b/tests/end_of_file_fixer_test.py @@ -1,8 +1,6 @@ -import io - import pytest -from pre_commit_hooks.end_of_file_fixer import fix_file +from pre_commit_hooks.end_of_file_fixer import _process_file from pre_commit_hooks.end_of_file_fixer import main @@ -23,11 +21,9 @@ TESTS = ( @pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS) -def test_fix_file(input_s, expected_retval, output): - file_obj = io.BytesIO(input_s) - ret = fix_file(file_obj) - assert file_obj.getvalue() == output - assert ret == expected_retval +def test_process_file(input_s, expected_retval, output): + processed = _process_file(input_s) + assert processed == output @pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS)