From bbcd31e00074aeab5cc15abaa2609ae38b8c398d Mon Sep 17 00:00:00 2001 From: Aniket Bhatnagar Date: Thu, 7 May 2020 21:02:12 +0100 Subject: [PATCH] Handled multiline dependencies --- pre_commit_hooks/requirements_txt_fixer.py | 16 ++++++++++++++-- tests/requirements_txt_fixer_test.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index dc41815..6190692 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -34,6 +34,18 @@ class Requirement: else: return self.name < requirement.name + def is_complete(self) -> bool: + return ( + self.value is not None and + not self.value.rstrip(b'\r\n').endswith(b'\\') + ) + + def append_value(self, value: bytes) -> None: + if self.value is not None: + self.value += value + else: + self.value = value + def fix_requirements(f: IO[bytes]) -> int: requirements: List[Requirement] = [] @@ -55,7 +67,7 @@ def fix_requirements(f: IO[bytes]) -> int: # If the most recent requirement object has a value, then it's # time to start building the next requirement object. - if not len(requirements) or requirements[-1].value is not None: + if not len(requirements) or requirements[-1].is_complete(): requirements.append(Requirement()) requirement = requirements[-1] @@ -73,7 +85,7 @@ def fix_requirements(f: IO[bytes]) -> int: elif line.startswith(b'#') or line.strip() == b'': requirement.comments.append(line) else: - requirement.value = line + requirement.append_value(line) # if a file ends in a comment, preserve it at the end if requirements[-1].value is None: diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index 7b9b07d..17a9a41 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -50,6 +50,24 @@ from pre_commit_hooks.requirements_txt_fixer import Requirement FAIL, b'Django\nijk\ngit+ssh://git_url@tag#egg=ocflib\n', ), + ( + b'b==1.0.0\n' + b'c=2.0.0 \\\n' + b' --hash=sha256:abcd\n' + b'a=3.0.0 \\\n' + b' --hash=sha256:a1b1c1d1', + FAIL, + b'a=3.0.0 \\\n' + b' --hash=sha256:a1b1c1d1\n' + b'b==1.0.0\n' + b'c=2.0.0 \\\n' + b' --hash=sha256:abcd\n', + ), + ( + b'a=2.0.0 \\\n --hash=sha256:abcd\nb==1.0.0\n', + PASS, + b'a=2.0.0 \\\n --hash=sha256:abcd\nb==1.0.0\n', + ), ), ) def test_integration(input_s, expected_retval, output, tmpdir):