diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py index c159071..fa9b7dd 100644 --- a/pre_commit_hooks/trailing_whitespace_fixer.py +++ b/pre_commit_hooks/trailing_whitespace_fixer.py @@ -1,24 +1,27 @@ from __future__ import print_function import argparse -import fileinput import os import sys from pre_commit_hooks.util import cmd_output -def _fix_file(filename, markdown=False): - for line in fileinput.input([filename], inplace=True): - # preserve trailing two-space for non-blank lines in markdown files - if markdown and (not line.isspace()) and (line.endswith(" \n")): - line = line.rstrip(' \n') - # only preserve if there are no trailing tabs or unusual whitespace - if not line[-1].isspace(): - print(line + " ") - continue +def _fix_file(filename, is_markdown): + with open(filename, mode='rb') as file_processed: + lines = file_processed.readlines() + lines = [_process_line(line, is_markdown) for line in lines] + with open(filename, mode='wb') as file_processed: + for line in lines: + file_processed.write(line) - print(line.rstrip()) + +def _process_line(line, is_markdown): + # preserve trailing two-space for non-blank lines in markdown files + eol = b'\r\n' if line[-2:] == b'\r\n' else b'\n' + if is_markdown and (not line.isspace()) and line.endswith(b' ' + eol): + return line.rstrip() + b' ' + eol + return line.rstrip() + eol def fix_trailing_whitespace(argv=None): @@ -64,14 +67,13 @@ def fix_trailing_whitespace(argv=None): .format(ext) ) - if bad_whitespace_files: - for bad_whitespace_file in bad_whitespace_files: - print('Fixing {0}'.format(bad_whitespace_file)) - _, extension = os.path.splitext(bad_whitespace_file.lower()) - _fix_file(bad_whitespace_file, all_markdown or extension in md_exts) - return 1 - else: - return 0 + return_code = 0 + for bad_whitespace_file in bad_whitespace_files: + print('Fixing {0}'.format(bad_whitespace_file)) + _, extension = os.path.splitext(bad_whitespace_file.lower()) + _fix_file(bad_whitespace_file, all_markdown or extension in md_exts) + return_code = 1 + return return_code if __name__ == '__main__': diff --git a/tests/trailing_whitespace_fixer_test.py b/tests/trailing_whitespace_fixer_test.py index 6f4fdfd..3a72ccb 100644 --- a/tests/trailing_whitespace_fixer_test.py +++ b/tests/trailing_whitespace_fixer_test.py @@ -42,7 +42,7 @@ def test_fixes_trailing_markdown_whitespace(filename, input_s, output, tmpdir): MD_TESTS_2 = ( ('foo.txt', 'foo \nbar \n \n', 'foo \nbar\n\n'), ('bar.Markdown', 'bar \nbaz\t\n\t\n', 'bar \nbaz\n\n'), - ('bar.MD', 'bar \nbaz\t \n\t\n', 'bar \nbaz\n\n'), + ('bar.MD', 'bar \nbaz\t \n\t\n', 'bar \nbaz \n\n'), ('.txt', 'baz \nquux \t\n\t\n', 'baz\nquux\n\n'), ('txt', 'foo \nbaz \n\t\n', 'foo\nbaz\n\n'), ) @@ -103,3 +103,12 @@ def test_no_markdown_linebreak_ext_opt(filename, input_s, output, tmpdir): def test_returns_zero_for_no_changes(): assert fix_trailing_whitespace([__file__]) == 0 + + +def test_preserve_non_utf8_file(tmpdir): + non_utf8_bytes_content = b'\xe9 \n\n' + path = tmpdir.join('file.txt') + path.write_binary(non_utf8_bytes_content) + ret = fix_trailing_whitespace([path.strpath]) + assert ret == 1 + assert path.size() == (len(non_utf8_bytes_content) - 1)