diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 5884394..21c71f4 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -12,10 +12,12 @@ FAIL = 1 class Requirement: UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?') + VERSION_MATCHER = re.compile(b'(?:={2,3}|!=|~=|>=?|<=?|@)\s*(?P[A-Za-z0-9./:]+)$') UNTIL_SEP = re.compile(rb'[^;\s]+') def __init__(self) -> None: self.value: bytes | None = None + self.version: str | None = None self.comments: list[bytes] = [] @property @@ -36,6 +38,17 @@ class Requirement: return name[:m.start()] + def has_version(self) -> bool: + return self.extract_version() is not None + + def extract_version(self): + matches = self.VERSION_MATCHER.search(self.value) + if matches: + self.version = matches.groups()[0].decode() + else: + self.version = None + return self.version + def __lt__(self, requirement: Requirement) -> bool: # \n means top of file comment, so always return True, # otherwise just do a string comparison with value. @@ -47,7 +60,7 @@ class Requirement: else: return self.name < requirement.name - def is_complete(self) -> bool: + def is_complete(self, require_version: bool = False) -> bool: return ( self.value is not None and not self.value.rstrip(b'\r\n').endswith(b'\\') @@ -60,7 +73,7 @@ class Requirement: self.value = value -def fix_requirements(f: IO[bytes]) -> int: +def fix_requirements(f: IO[bytes], require_version: bool = False) -> int: requirements: list[Requirement] = [] before = list(f) after: list[bytes] = [] @@ -113,33 +126,53 @@ def fix_requirements(f: IO[bytes]) -> int: if req.value != b'pkg-resources==0.0.0\n' ] + missing_versions = [] for requirement in sorted(requirements): after.extend(requirement.comments) assert requirement.value, requirement.value after.append(requirement.value) + if require_version and not requirement.has_version(): + missing_versions.append(requirement.value.decode().strip()) after.extend(rest) after_string = b''.join(after) - if before_string == after_string: - return PASS - else: + # If the version is required but missing, we return FAIL, + # but still write the fixes to the file, because the pip install + # will work even if the req file is missing versions. + # We could block the write if versions are missing, too. This is + # something we should discuss in the PR review. + outcome = PASS + + if len(missing_versions) > 0: + print("Missing versions in:", ", ".join(missing_versions)) + outcome = FAIL + + if before_string != after_string: f.seek(0) f.write(after_string) f.truncate() - return FAIL + outcome = FAIL + + return outcome def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') + parser.add_argument('-r', '--require_version', + required=False, + help='Use this to require each requirement to include a version number', + action='store_true', + default=False + ) args = parser.parse_args(argv) retv = PASS for arg in args.filenames: with open(arg, 'rb+') as file_obj: - ret_for_file = fix_requirements(file_obj) + ret_for_file = fix_requirements(file_obj, args.require_version) if ret_for_file: print(f'Sorting {arg}') diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index b725afa..69a9719 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -122,9 +122,18 @@ def test_requirement_object(): requirement_bar = Requirement() requirement_bar.value = b'bar' + requirement_baz = Requirement() + requirement_baz.value = b'baz>=1.2.3' + # This may look redundant, but we need to test both foo.__lt__(bar) and # bar.__lt__(foo) assert requirement_foo > top_of_file assert top_of_file < requirement_foo assert requirement_foo > requirement_bar assert requirement_bar < requirement_foo + + # Test the version extraction code + assert not requirement_foo.has_version() + assert not requirement_bar.has_version() + assert requirement_baz.has_version() + assert requirement_baz.version == '1.2.3'