diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index 69a9719..2efaea6 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import _pytest import pytest from pre_commit_hooks.requirements_txt_fixer import FAIL @@ -111,6 +112,39 @@ def test_integration(input_s, expected_retval, output, tmpdir): assert output_retval == expected_retval +def test_require_version(tmpdir: _pytest._py.path.LocalPath) -> None: + # default: don't require version + path = tmpdir.join('file.txt') + input_s = ( + b'a==1\n' + b'bbbb\n' + b'c>=1\n' + ) + path.write_binary(input_s) + output_retval = main([str(path)]) + assert output_retval == PASS + + # require version, and all the versions are specified + input_s = ( + b'a==1\n' + b'bbbb~=2.0\n' + b'c>=1\n' + ) + path.write_binary(input_s) + output_retval = main([str(path), '--require_version']) + assert output_retval == PASS + + # require version, and some of the versions are missing + input_s = ( + b'a\n' + b'bbbb~=2.0\n' + b'c>=1\n' + ) + path.write_binary(input_s) + output_retval = main([str(path), '--require_version']) + assert output_retval == FAIL + + def test_requirement_object(): top_of_file = Requirement() top_of_file.comments.append(b'#foo') @@ -125,6 +159,8 @@ def test_requirement_object(): requirement_baz = Requirement() requirement_baz.value = b'baz>=1.2.3' + requirement_empty = Requirement() + # This may look redundant, but we need to test both foo.__lt__(bar) and # bar.__lt__(foo) assert requirement_foo > top_of_file @@ -133,6 +169,7 @@ def test_requirement_object(): assert requirement_bar < requirement_foo # Test the version extraction code + assert requirement_empty.extract_version() is None assert not requirement_foo.has_version() assert not requirement_bar.has_version() assert requirement_baz.has_version()