diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 8ce8ec6..2c640f5 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -36,6 +36,15 @@ class Requirement: return name[:m.start()] + @property + def sort_key(self) -> tuple[int, bytes]: + if self.name == b'--index-url': + return (0, self.name) + elif self.name == b'--extra-index-url': + return (1, self.name) + else: + return (2, self.name) + def __lt__(self, requirement: Requirement) -> bool: # \n means top of file comment, so always return True, # otherwise just do a string comparison with value. @@ -50,7 +59,7 @@ class Requirement: # with comments is kept) if self.name == requirement.name: return bool(self.comments) > bool(requirement.comments) - return self.name < requirement.name + return self.sort_key < requirement.sort_key def is_complete(self) -> bool: return ( diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index c0d2c65..35f526c 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -107,6 +107,16 @@ from pre_commit_hooks.requirements_txt_fixer import Requirement PASS, b'a=2.0.0 \\\n --hash=sha256:abcd\nb==1.0.0\n', ), + + ( + b'--extra-index-url https://example-extra/simple\n' + b'--index-url https://example-main/simple\n' + b'requests==2.31.0\n', + FAIL, + b'--index-url https://example-main/simple\n' + b'--extra-index-url https://example-extra/simple\n' + b'requests==2.31.0\n', + ), ), ) def test_integration(input_s, expected_retval, output, tmpdir):