diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 41e1ffc..ffabf2a 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -3,6 +3,10 @@ from __future__ import print_function import argparse +PASS = 0 +FAIL = 1 + + class Requirement(object): def __init__(self): @@ -30,14 +34,14 @@ class Requirement(object): def fix_requirements(f): requirements = [] - before = list(f) + before = tuple(f) after = [] before_string = b''.join(before) # If the file is empty (i.e. only whitespace/newlines) exit early if before_string.strip() == b'': - return 0 + return PASS for line in before: # If the most recent requirement object has a value, then it's @@ -60,19 +64,18 @@ def fix_requirements(f): requirement.value = line for requirement in sorted(requirements): - for comment in requirement.comments: - after.append(comment) + after.extend(requirement.comments) after.append(requirement.value) after_string = b''.join(after) if before_string == after_string: - return 0 + return PASS else: f.seek(0) f.write(after_string) f.truncate() - return 1 + return FAIL def fix_requirements_txt(argv=None): @@ -80,7 +83,7 @@ def fix_requirements_txt(argv=None): parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) - retv = 0 + retv = PASS for arg in args.filenames: with open(arg, 'rb+') as file_obj: diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index 33f6a47..3681cc6 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -1,33 +1,41 @@ import pytest +from pre_commit_hooks.requirements_txt_fixer import FAIL from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt +from pre_commit_hooks.requirements_txt_fixer import PASS from pre_commit_hooks.requirements_txt_fixer import Requirement -# Input, expected return value, expected output -TESTS = ( - (b'', 0, b''), - (b'\n', 0, b'\n'), - (b'foo\nbar\n', 1, b'bar\nfoo\n'), - (b'bar\nfoo\n', 0, b'bar\nfoo\n'), - (b'#comment1\nfoo\n#comment2\nbar\n', 1, b'#comment2\nbar\n#comment1\nfoo\n'), - (b'#comment1\nbar\n#comment2\nfoo\n', 0, b'#comment1\nbar\n#comment2\nfoo\n'), - (b'#comment\n\nfoo\nbar\n', 1, b'#comment\n\nbar\nfoo\n'), - (b'#comment\n\nbar\nfoo\n', 0, b'#comment\n\nbar\nfoo\n'), - (b'\nfoo\nbar\n', 1, b'bar\n\nfoo\n'), - (b'\nbar\nfoo\n', 0, b'\nbar\nfoo\n'), - (b'pyramid==1\npyramid-foo==2\n', 0, b'pyramid==1\npyramid-foo==2\n'), - (b'ocflib\nDjango\nPyMySQL\n', 1, b'Django\nocflib\nPyMySQL\n'), - (b'-e git+ssh://git_url@tag#egg=ocflib\nDjango\nPyMySQL\n', 1, b'Django\n-e git+ssh://git_url@tag#egg=ocflib\nPyMySQL\n'), + +@pytest.mark.parametrize( + ('input_s', 'expected_retval', 'output'), + ( + (b'', PASS, b''), + (b'\n', PASS, b'\n'), + (b'foo\nbar\n', FAIL, b'bar\nfoo\n'), + (b'bar\nfoo\n', PASS, b'bar\nfoo\n'), + (b'#comment1\nfoo\n#comment2\nbar\n', FAIL, b'#comment2\nbar\n#comment1\nfoo\n'), + (b'#comment1\nbar\n#comment2\nfoo\n', PASS, b'#comment1\nbar\n#comment2\nfoo\n'), + (b'#comment\n\nfoo\nbar\n', FAIL, b'#comment\n\nbar\nfoo\n'), + (b'#comment\n\nbar\nfoo\n', PASS, b'#comment\n\nbar\nfoo\n'), + (b'\nfoo\nbar\n', FAIL, b'bar\n\nfoo\n'), + (b'\nbar\nfoo\n', PASS, b'\nbar\nfoo\n'), + (b'pyramid==1\npyramid-foo==2\n', PASS, b'pyramid==1\npyramid-foo==2\n'), + (b'ocflib\nDjango\nPyMySQL\n', FAIL, b'Django\nocflib\nPyMySQL\n'), + ( + b'-e git+ssh://git_url@tag#egg=ocflib\nDjango\nPyMySQL\n', + FAIL, + b'Django\n-e git+ssh://git_url@tag#egg=ocflib\nPyMySQL\n' + ), + ) ) - - -@pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS) def test_integration(input_s, expected_retval, output, tmpdir): path = tmpdir.join('file.txt') path.write_binary(input_s) - assert fix_requirements_txt([path.strpath]) == expected_retval + output_retval = fix_requirements_txt([path.strpath]) + assert path.read_binary() == output + assert output_retval == expected_retval def test_requirement_object():