diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 21c71f4..5703d66 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -12,7 +12,7 @@ FAIL = 1 class Requirement: UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?') - VERSION_MATCHER = re.compile(b'(?:={2,3}|!=|~=|>=?|<=?|@)\s*(?P[A-Za-z0-9./:]+)$') + VERSION_MATCHER = re.compile(br'(?:={2,3}|!=|~=|>=?|<=?|@)\s*(?P[A-Za-z0-9./:]+)$') # noqa: E501 UNTIL_SEP = re.compile(rb'[^;\s]+') def __init__(self) -> None: @@ -41,7 +41,9 @@ class Requirement: def has_version(self) -> bool: return self.extract_version() is not None - def extract_version(self): + def extract_version(self) -> str | None: + if not self.value: + return None matches = self.VERSION_MATCHER.search(self.value) if matches: self.version = matches.groups()[0].decode() @@ -145,7 +147,7 @@ def fix_requirements(f: IO[bytes], require_version: bool = False) -> int: outcome = PASS if len(missing_versions) > 0: - print("Missing versions in:", ", ".join(missing_versions)) + print('Missing versions in:', ', '.join(missing_versions)) outcome = FAIL if before_string != after_string: @@ -160,11 +162,12 @@ def fix_requirements(f: IO[bytes], require_version: bool = False) -> int: def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') - parser.add_argument('-r', '--require_version', + parser.add_argument( + '-r', '--require_version', required=False, - help='Use this to require each requirement to include a version number', + help='Each requirement must include a version number', action='store_true', - default=False + default=False, ) args = parser.parse_args(argv)