mirror of
https://github.com/pre-commit/pre-commit-hooks.git
synced 2026-03-30 10:16:54 +00:00
Added option to check if the packages in the requirements file(s) include a version string. This lets you enforce a requirement that all your dependency versions be specified to avoid breakage when a package's version changes under you.
This commit is contained in:
parent
dd97157e54
commit
57bf19d519
2 changed files with 49 additions and 7 deletions
|
|
@ -12,10 +12,12 @@ FAIL = 1
|
|||
|
||||
class Requirement:
|
||||
UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?')
|
||||
VERSION_MATCHER = re.compile(b'(?:={2,3}|!=|~=|>=?|<=?|@)\s*(?P<version>[A-Za-z0-9./:]+)$')
|
||||
UNTIL_SEP = re.compile(rb'[^;\s]+')
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.value: bytes | None = None
|
||||
self.version: str | None = None
|
||||
self.comments: list[bytes] = []
|
||||
|
||||
@property
|
||||
|
|
@ -36,6 +38,17 @@ class Requirement:
|
|||
|
||||
return name[:m.start()]
|
||||
|
||||
def has_version(self) -> bool:
|
||||
return self.extract_version() is not None
|
||||
|
||||
def extract_version(self):
|
||||
matches = self.VERSION_MATCHER.search(self.value)
|
||||
if matches:
|
||||
self.version = matches.groups()[0].decode()
|
||||
else:
|
||||
self.version = None
|
||||
return self.version
|
||||
|
||||
def __lt__(self, requirement: Requirement) -> bool:
|
||||
# \n means top of file comment, so always return True,
|
||||
# otherwise just do a string comparison with value.
|
||||
|
|
@ -47,7 +60,7 @@ class Requirement:
|
|||
else:
|
||||
return self.name < requirement.name
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
def is_complete(self, require_version: bool = False) -> bool:
|
||||
return (
|
||||
self.value is not None and
|
||||
not self.value.rstrip(b'\r\n').endswith(b'\\')
|
||||
|
|
@ -60,7 +73,7 @@ class Requirement:
|
|||
self.value = value
|
||||
|
||||
|
||||
def fix_requirements(f: IO[bytes]) -> int:
|
||||
def fix_requirements(f: IO[bytes], require_version: bool = False) -> int:
|
||||
requirements: list[Requirement] = []
|
||||
before = list(f)
|
||||
after: list[bytes] = []
|
||||
|
|
@ -113,33 +126,53 @@ def fix_requirements(f: IO[bytes]) -> int:
|
|||
if req.value != b'pkg-resources==0.0.0\n'
|
||||
]
|
||||
|
||||
missing_versions = []
|
||||
for requirement in sorted(requirements):
|
||||
after.extend(requirement.comments)
|
||||
assert requirement.value, requirement.value
|
||||
after.append(requirement.value)
|
||||
if require_version and not requirement.has_version():
|
||||
missing_versions.append(requirement.value.decode().strip())
|
||||
after.extend(rest)
|
||||
|
||||
after_string = b''.join(after)
|
||||
|
||||
if before_string == after_string:
|
||||
return PASS
|
||||
else:
|
||||
# If the version is required but missing, we return FAIL,
|
||||
# but still write the fixes to the file, because the pip install
|
||||
# will work even if the req file is missing versions.
|
||||
# We could block the write if versions are missing, too. This is
|
||||
# something we should discuss in the PR review.
|
||||
outcome = PASS
|
||||
|
||||
if len(missing_versions) > 0:
|
||||
print("Missing versions in:", ", ".join(missing_versions))
|
||||
outcome = FAIL
|
||||
|
||||
if before_string != after_string:
|
||||
f.seek(0)
|
||||
f.write(after_string)
|
||||
f.truncate()
|
||||
return FAIL
|
||||
outcome = FAIL
|
||||
|
||||
return outcome
|
||||
|
||||
|
||||
def main(argv: Sequence[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
|
||||
parser.add_argument('-r', '--require_version',
|
||||
required=False,
|
||||
help='Use this to require each requirement to include a version number',
|
||||
action='store_true',
|
||||
default=False
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
retv = PASS
|
||||
|
||||
for arg in args.filenames:
|
||||
with open(arg, 'rb+') as file_obj:
|
||||
ret_for_file = fix_requirements(file_obj)
|
||||
ret_for_file = fix_requirements(file_obj, args.require_version)
|
||||
|
||||
if ret_for_file:
|
||||
print(f'Sorting {arg}')
|
||||
|
|
|
|||
|
|
@ -122,9 +122,18 @@ def test_requirement_object():
|
|||
requirement_bar = Requirement()
|
||||
requirement_bar.value = b'bar'
|
||||
|
||||
requirement_baz = Requirement()
|
||||
requirement_baz.value = b'baz>=1.2.3'
|
||||
|
||||
# This may look redundant, but we need to test both foo.__lt__(bar) and
|
||||
# bar.__lt__(foo)
|
||||
assert requirement_foo > top_of_file
|
||||
assert top_of_file < requirement_foo
|
||||
assert requirement_foo > requirement_bar
|
||||
assert requirement_bar < requirement_foo
|
||||
|
||||
# Test the version extraction code
|
||||
assert not requirement_foo.has_version()
|
||||
assert not requirement_bar.has_version()
|
||||
assert requirement_baz.has_version()
|
||||
assert requirement_baz.version == '1.2.3'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue