diff --git a/README.md b/README.md index a2d38a7..98e3ac7 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ Add this to your `.pre-commit-config.yaml` - `flake8` - Run flake8 on your python files - `name-tests-test` - Assert that files in tests/ end in _test.py - `pyflakes` - Run pyflakes on your python files +- `requirements-txt-fixer` - Sorts entries in requirements.txt - `trailing-whitespace` - Trims trailing whitespace. ### As a standalone package diff --git a/hooks.yaml b/hooks.yaml index 3fa9bd9..31b3cd0 100644 --- a/hooks.yaml +++ b/hooks.yaml @@ -47,6 +47,12 @@ entry: pyflakes language: python files: \.py$ +- id: requirements-txt-fixer + name: Fix requirements.txt + description: Sorts entries in requirements.txt + entry: requirements-txt-fixer + language: python + files: requirements.*\.txt$ - id: trailing-whitespace name: Trim Trailing Whitespace description: This hook trims trailing whitespace. diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py new file mode 100644 index 0000000..22fdb52 --- /dev/null +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -0,0 +1,82 @@ +from __future__ import print_function + +import argparse + +from pre_commit_hooks.util import entry + + +class Requirement(object): + + def __init__(self): + super(Requirement, self).__init__() + self.value = None + self.comments = [] + + def __lt__(self, requirement): + # \n means top of file comment, so always return True, + # otherwise just do a string comparison with value. + if self.value == b'\n': + return True + elif requirement.value == b'\n': + return False + else: + return self.value < requirement.value + + +def fix_requirements(f): + requirements = [] + before = [] + after = [] + + for line in f: + before.append(line) + + # If the most recent requirement object has a value, then it's time to + # start building the next requirement object. + if not len(requirements) or requirements[-1].value is not None: + requirements.append(Requirement()) + + requirement = requirements[-1] + + # If we see a newline before any requirements, then this is a top of + # file comment. + if len(requirements) == 1 and line.strip() == b'': + if len(requirement.comments) and requirement.comments[0].startswith(b'#'): + requirement.value = b'\n' + else: + requirement.comments.append(line) + elif line.startswith(b'#') or line.strip() == b'': + requirement.comments.append(line) + else: + requirement.value = line + + for requirement in sorted(requirements): + for comment in requirement.comments: + after.append(comment) + after.append(requirement.value) + + before_string = b''.join(before) + after_string = b''.join(after) + + if before_string == after_string: + return 0 + else: + f.seek(0) + f.write(after_string) + f.truncate() + return 1 + + +@entry +def fix_requirements_txt(argv): + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retv = 0 + + for arg in args.filenames: + with open(arg, 'rb+') as f: + retv |= fix_requirements(f) + + return retv diff --git a/requirements-dev.txt b/requirements-dev.txt index 42ddf61..ae68bdc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,6 +3,6 @@ coverage flake8 mock -pylint +pylint<1.4 pytest pre-commit diff --git a/setup.py b/setup.py index d7f1baa..c801a63 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ setup( 'debug-statement-hook = pre_commit_hooks.debug_statement_hook:debug_statement_hook', 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:end_of_file_fixer', 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:validate_files', + 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:fix_requirements', 'trailing-whitespace-fixer = pre_commit_hooks.trailing_whitespace_fixer:fix_trailing_whitespace', ], }, diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py new file mode 100644 index 0000000..49b0bb7 --- /dev/null +++ b/tests/requirements_txt_fixer_test.py @@ -0,0 +1,47 @@ +import os.path +import pytest + +from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt +from pre_commit_hooks.requirements_txt_fixer import Requirement + +# Input, expected return value, expected output +TESTS = ( + (b'foo\nbar\n', 1, b'bar\nfoo\n'), + (b'bar\nfoo\n', 0, b'bar\nfoo\n'), + (b'#comment1\nfoo\n#comment2\nbar\n', 1, b'#comment2\nbar\n#comment1\nfoo\n'), + (b'#comment1\nbar\n#comment2\nfoo\n', 0, b'#comment1\nbar\n#comment2\nfoo\n'), + (b'#comment\n\nfoo\nbar\n', 1, b'#comment\n\nbar\nfoo\n'), + (b'#comment\n\nbar\nfoo\n', 0, b'#comment\n\nbar\nfoo\n'), + (b'\nfoo\nbar\n', 1, b'bar\n\nfoo\n'), + (b'\nbar\nfoo\n', 0, b'\nbar\nfoo\n'), +) + + +@pytest.mark.parametrize(('input', 'expected_retval', 'output'), TESTS) +def test_integration(input, expected_retval, output, tmpdir): + path = os.path.join(tmpdir.strpath, 'file.txt') + + with open(path, 'wb') as file_obj: + file_obj.write(input) + + assert fix_requirements_txt([path]) == expected_retval + assert open(path, 'rb').read() == output + + +def test_requirement_object(): + top_of_file = Requirement() + top_of_file.comments.append('#foo') + top_of_file.value = b'\n' + + requirement_foo = Requirement() + requirement_foo.value = b'foo' + + requirement_bar = Requirement() + requirement_bar.value = b'bar' + + # This may look redundant, but we need to test both foo.__lt__(bar) and + # bar.__lt__(foo) + assert requirement_foo > top_of_file + assert top_of_file < requirement_foo + assert requirement_foo > requirement_bar + assert requirement_bar < requirement_foo