diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py index 02bdbcc..97ee45b 100644 --- a/pre_commit_hooks/file_contents_sorter.py +++ b/pre_commit_hooks/file_contents_sorter.py @@ -22,6 +22,55 @@ PASS = 0 FAIL = 1 +class Line: + """Wrapper to ignore end-of-line characters for sorting and comparison""" + + def __init__(self, value: bytes, eol: bytes): + self._value = value + # Add an EOL if none present (can only happen to the last line) + if not self._value.endswith(b'\n'): + self._value += eol + + def without_eol(self) -> bytes: + return self._value.rstrip(b'\n\r') + + def unwrap(self) -> bytes: + return self._value + + @classmethod + def key( + cls, + key: Callable[[bytes], Any] | None = None, + ) -> Callable[[Line], Any]: + if key is None: + return cls.without_eol + else: + def eol_key(val: Line) -> Any: + return key(val.without_eol()) + + return eol_key + + def __eq__(self, o: object) -> bool: + if not isinstance(o, Line): + return NotImplemented + return self.without_eol().__eq__(o.without_eol()) + + def __hash__(self) -> int: + return self.without_eol().__hash__() + + +def guess_eol(lines: list[bytes]) -> bytes: + if len(lines) == 0: + return b'\n' + + for eol in [b'\r\n', b'\n']: + if lines[0].endswith(eol): + return eol + + # Prefer '\n' if the first (only) line does not have a line ending + return b'\n' + + def sort_file_contents( f: IO[bytes], key: Callable[[bytes], Any] | None, @@ -29,18 +78,16 @@ def sort_file_contents( unique: bool = False, ) -> int: before = list(f) - lines: Iterable[bytes] = ( - line.rstrip(b'\n\r') for line in before if line.strip() + eol = guess_eol(before) + lines: Iterable[Line] = ( + Line(line, eol) for line in before if line.strip() ) if unique: lines = set(lines) - after = sorted(lines, key=key) + after = sorted(lines, key=Line.key(key)) before_string = b''.join(before) - after_string = b'\n'.join(after) - - if after_string: - after_string += b'\n' + after_string = b''.join(line.unwrap() for line in after) if before_string == after_string: return PASS diff --git a/tests/file_contents_sorter_test.py b/tests/file_contents_sorter_test.py index 49b3b79..6d232fc 100644 --- a/tests/file_contents_sorter_test.py +++ b/tests/file_contents_sorter_test.py @@ -17,6 +17,7 @@ from pre_commit_hooks.file_contents_sorter import PASS (b'missing_newline', [], FAIL, b'missing_newline\n'), (b'newline\nmissing', [], FAIL, b'missing\nnewline\n'), (b'missing\nnewline', [], FAIL, b'missing\nnewline\n'), + (b'missing\r\nnewline', [], FAIL, b'missing\r\nnewline\r\n'), (b'alpha\nbeta\n', [], PASS, b'alpha\nbeta\n'), (b'beta\nalpha\n', [], FAIL, b'alpha\nbeta\n'), (b'C\nc\n', [], PASS, b'C\nc\n'), @@ -67,6 +68,12 @@ from pre_commit_hooks.file_contents_sorter import PASS FAIL, b'Fie\nFoe\nfee\nfum\n', ), + ( + b'Fie\r\nFie\nFoe\nfee\nfee\r\nfum\n', + ['--unique'], + FAIL, + b'Fie\r\nFoe\nfee\nfum\n', + ), ( b'fee\nFie\nFoe\nfum\n', ['--unique', '--ignore-case'], @@ -79,6 +86,24 @@ from pre_commit_hooks.file_contents_sorter import PASS FAIL, b'fee\nFie\nFoe\nfum\n', ), + ( + b'linefeed\r\ncarriage_return\r\n', + [], + FAIL, + b'carriage_return\r\nlinefeed\r\n', + ), + ( + b'carriage_return\r\nlinefeed\r\n', + [], + PASS, + b'carriage_return\r\nlinefeed\r\n', + ), + ( + b'a\na\r\na\r\na\na\r\na\n', + [], + PASS, + b'a\na\r\na\r\na\na\r\na\n', + ), ), ) def test_integration(input_s, argv, expected_retval, output, tmpdir):