diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 6190692..78103a1 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -1,4 +1,5 @@ import argparse +import re from typing import IO from typing import List from typing import Optional @@ -10,6 +11,9 @@ FAIL = 1 class Requirement: + UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?') + UNTIL_SEP = re.compile(rb'[^;\s]+') + def __init__(self) -> None: self.value: Optional[bytes] = None self.comments: List[bytes] = [] @@ -17,11 +21,20 @@ class Requirement: @property def name(self) -> bytes: assert self.value is not None, self.value + name = self.value.lower() for egg in (b'#egg=', b'&egg='): if egg in self.value: - return self.value.lower().partition(egg)[-1] + return name.partition(egg)[-1] - return self.value.lower().partition(b'==')[0] + m = self.UNTIL_SEP.match(name) + assert m is not None + + name = m.group() + m = self.UNTIL_COMPARISON.search(name) + if not m: + return name + + return name[:m.start()] def __lt__(self, requirement: 'Requirement') -> int: # \n means top of file comment, so always return True, diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index 17a9a41..fae5a72 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -33,9 +33,28 @@ from pre_commit_hooks.requirements_txt_fixer import Requirement (b'\nfoo\nbar\n', FAIL, b'bar\n\nfoo\n'), (b'\nbar\nfoo\n', PASS, b'\nbar\nfoo\n'), ( - b'pyramid==1\npyramid-foo==2\n', - PASS, - b'pyramid==1\npyramid-foo==2\n', + b'pyramid-foo==1\npyramid>=2\n', + FAIL, + b'pyramid>=2\npyramid-foo==1\n', + ), + ( + b'a==1\n' + b'c>=1\n' + b'bbbb!=1\n' + b'c-a>=1;python_version>="3.6"\n' + b'e>=2\n' + b'd>2\n' + b'g<2\n' + b'f<=2\n', + FAIL, + b'a==1\n' + b'bbbb!=1\n' + b'c>=1\n' + b'c-a>=1;python_version>="3.6"\n' + b'd>2\n' + b'e>=2\n' + b'f<=2\n' + b'g<2\n', ), (b'ocflib\nDjango\nPyMySQL\n', FAIL, b'Django\nocflib\nPyMySQL\n'), (