Add "--check" option support to trailing_whitespace hook

This commit is contained in:
zhiwei.meng 2026-04-02 15:20:54 +08:00
parent 3b8b26a097
commit 4842364696
2 changed files with 40 additions and 7 deletions

View file

@ -9,14 +9,17 @@ def _fix_file(
filename: str, filename: str,
is_markdown: bool, is_markdown: bool,
chars: bytes | None, chars: bytes | None,
check_only: bool = False,
error_lines: list[int] = None,
) -> bool: ) -> bool:
with open(filename, mode='rb') as file_processed: with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines() lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown, chars) for line in lines] newlines = [_process_line(line, is_markdown, chars, line_num, error_lines) for line_num, line in enumerate(lines)]
if newlines != lines: if newlines != lines:
with open(filename, mode='wb') as file_processed: if not check_only:
for line in newlines: with open(filename, mode='wb') as file_processed:
file_processed.write(line) for line in newlines:
file_processed.write(line)
return True return True
else: else:
return False return False
@ -26,7 +29,10 @@ def _process_line(
line: bytes, line: bytes,
is_markdown: bool, is_markdown: bool,
chars: bytes | None, chars: bytes | None,
line_num: int,
error_lines: list[int] | None
) -> bytes: ) -> bytes:
org_line = line
if line[-2:] == b'\r\n': if line[-2:] == b'\r\n':
eol = b'\r\n' eol = b'\r\n'
line = line[:-2] line = line[:-2]
@ -38,11 +44,15 @@ def _process_line(
# preserve trailing two-space for non-blank lines in markdown files # preserve trailing two-space for non-blank lines in markdown files
if is_markdown and (not line.isspace()) and line.endswith(b' '): if is_markdown and (not line.isspace()) and line.endswith(b' '):
return line[:-2].rstrip(chars) + b' ' + eol return line[:-2].rstrip(chars) + b' ' + eol
return line.rstrip(chars) + eol result = line.rstrip(chars) + eol
if error_lines is not None and org_line != result:
error_lines.append(line_num+1)
return result
def main(argv: Sequence[str] | None = None) -> int: def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--check', action='store_true', help='Check without fixing')
parser.add_argument( parser.add_argument(
'--no-markdown-linebreak-ext', '--no-markdown-linebreak-ext',
action='store_true', action='store_true',
@ -93,8 +103,14 @@ def main(argv: Sequence[str] | None = None) -> int:
for filename in args.filenames: for filename in args.filenames:
_, extension = os.path.splitext(filename.lower()) _, extension = os.path.splitext(filename.lower())
md = all_markdown or extension in md_exts md = all_markdown or extension in md_exts
if _fix_file(filename, md, chars): error_lines = []
print(f'Fixing {filename}') if _fix_file(filename, md, chars, args.check, error_lines):
if args.check:
location = ",".join(map(str, error_lines[:4]))
location += "..." if len(error_lines) > 4 else ""
print(f'Trailing whitespace check failed: {filename} @ {location}')
else:
print(f'Fixing {filename}')
return_code = 1 return_code = 1
return return_code return return_code

View file

@ -18,6 +18,23 @@ def test_fixes_trailing_whitespace(input_s, expected, tmpdir):
assert main((str(path),)) == 1 assert main((str(path),)) == 1
assert path.read() == expected assert path.read() == expected
@pytest.mark.parametrize(
('input_s', 'exit_code', "lines"),
(
('foo \nbar \n', 1, [1,2]),
('bar\t\nbaz\t\n', 1, [1,2]),
('bar\nbaz\t\n', 1, [2]),
),
)
def test_fixes_trailing_whitespace_check_only(capsys, input_s, exit_code, lines, tmpdir):
path = tmpdir.join('file.md')
path.write(input_s)
assert main(('--check', str(path),)) == exit_code
assert path.read() == input_s
captured = capsys.readouterr()
location = "@ " + ','.join(map(str, lines))
assert location in captured.out
def test_ok_no_newline_end_of_file(tmpdir): def test_ok_no_newline_end_of_file(tmpdir):
filename = tmpdir.join('f') filename = tmpdir.join('f')