Merge pull request #421 from iconmaster5326/master

trailing-whitespace: add option for custom chars to strip
This commit is contained in:
Anthony Sottile 2019-10-25 09:38:13 -07:00 committed by GitHub
commit d61d4a26db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 8 deletions

View file

@ -111,6 +111,8 @@ Add this to your `.pre-commit-config.yaml`
use `args: [--markdown-linebreak-ext=md]` (or other extensions used use `args: [--markdown-linebreak-ext=md]` (or other extensions used
by your markdownfiles). If for some reason you want to treat all files by your markdownfiles). If for some reason you want to treat all files
as markdown, use `--markdown-linebreak-ext=*`. as markdown, use `--markdown-linebreak-ext=*`.
- By default, this hook trims all whitespace from the ends of lines.
To specify a custom set of characters to trim instead, use `args: [--chars,"<chars to trim>"]`.
### Deprecated / replaced hooks ### Deprecated / replaced hooks

View file

@ -7,10 +7,11 @@ from typing import Optional
from typing import Sequence from typing import Sequence
def _fix_file(filename, is_markdown): # type: (str, bool) -> bool def _fix_file(filename, is_markdown, chars):
# type: (str, bool, Optional[bytes]) -> bool
with open(filename, mode='rb') as file_processed: with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines() lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown) for line in lines] newlines = [_process_line(line, is_markdown, chars) for line in lines]
if newlines != lines: if newlines != lines:
with open(filename, mode='wb') as file_processed: with open(filename, mode='wb') as file_processed:
for line in newlines: for line in newlines:
@ -20,17 +21,20 @@ def _fix_file(filename, is_markdown): # type: (str, bool) -> bool
return False return False
def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes def _process_line(line, is_markdown, chars):
# type: (bytes, bool, Optional[bytes]) -> bytes
if line[-2:] == b'\r\n': if line[-2:] == b'\r\n':
eol = b'\r\n' eol = b'\r\n'
line = line[:-2]
elif line[-1:] == b'\n': elif line[-1:] == b'\n':
eol = b'\n' eol = b'\n'
line = line[:-1]
else: else:
eol = b'' eol = b''
# preserve trailing two-space for non-blank lines in markdown files # preserve trailing two-space for non-blank lines in markdown files
if is_markdown and (not line.isspace()) and line.endswith(b' ' + eol): if is_markdown and (not line.isspace()) and line.endswith(b' '):
return line.rstrip() + b' ' + eol return line[:-2].rstrip(chars) + b' ' + eol
return line.rstrip() + eol return line.rstrip(chars) + eol
def main(argv=None): # type: (Optional[Sequence[str]]) -> int def main(argv=None): # type: (Optional[Sequence[str]]) -> int
@ -50,6 +54,13 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
'default: %(default)s' 'default: %(default)s'
), ),
) )
parser.add_argument(
'--chars',
help=(
'The set of characters to strip from the end of lines. '
'Defaults to all whitespace characters.'
),
)
parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -73,12 +84,12 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
" (probably filename; use '--markdown-linebreak-ext=EXT')" " (probably filename; use '--markdown-linebreak-ext=EXT')"
.format(ext), .format(ext),
) )
chars = None if args.chars is None else args.chars.encode('utf-8')
return_code = 0 return_code = 0
for filename in args.filenames: for filename in args.filenames:
_, extension = os.path.splitext(filename.lower()) _, extension = os.path.splitext(filename.lower())
md = all_markdown or extension in md_exts md = all_markdown or extension in md_exts
if _fix_file(filename, md): if _fix_file(filename, md, chars):
print('Fixing {}'.format(filename)) print('Fixing {}'.format(filename))
return_code = 1 return_code = 1
return return_code return return_code

View file

@ -78,3 +78,27 @@ def test_preserve_non_utf8_file(tmpdir):
ret = main([path.strpath]) ret = main([path.strpath])
assert ret == 1 assert ret == 1
assert path.size() == (len(non_utf8_bytes_content) - 1) assert path.size() == (len(non_utf8_bytes_content) - 1)
def test_custom_charset_change(tmpdir):
# strip spaces only, no tabs
path = tmpdir.join('file.txt')
path.write('\ta \t \n')
ret = main([path.strpath, '--chars', ' '])
assert ret == 1
assert path.read() == '\ta \t\n'
def test_custom_charset_no_change(tmpdir):
path = tmpdir.join('file.txt')
path.write('\ta \t\n')
ret = main([path.strpath, '--chars', ' '])
assert ret == 0
def test_markdown_with_custom_charset(tmpdir):
path = tmpdir.join('file.md')
path.write('\ta \t \n')
ret = main([path.strpath, '--chars', ' ', '--markdown-linebreak-ext', '*'])
assert ret == 1
assert path.read() == '\ta \t \n'