Merge pull request #360 from pre-commit/mypy

Apply typing to all of pre-commit-hooks
This commit is contained in:
Anthony Sottile 2019-02-02 10:47:43 -08:00 committed by GitHub
commit 634383cffd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
54 changed files with 401 additions and 264 deletions

9
.gitignore vendored
View file

@ -1,16 +1,11 @@
*.egg-info *.egg-info
*.iml
*.py[co] *.py[co]
.*.sw[a-z] .*.sw[a-z]
.pytest_cache
.coverage .coverage
.idea
.project
.pydevproject
.tox .tox
.venv.touch .venv.touch
/.mypy_cache
/.pytest_cache
/venv* /venv*
coverage-html coverage-html
dist dist
# SublimeText project/workspace files
*.sublime-*

View file

@ -27,7 +27,7 @@ repos:
rev: v1.3.5 rev: v1.3.5
hooks: hooks:
- id: reorder-python-imports - id: reorder-python-imports
language_version: python2.7 language_version: python3
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v1.11.1 rev: v1.11.1
hooks: hooks:
@ -36,3 +36,8 @@ repos:
rev: v0.7.1 rev: v0.7.1
hooks: hooks:
- id: add-trailing-comma - id: add-trailing-comma
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.660
hooks:
- id: mypy
language_version: python3

View file

@ -1,3 +1,4 @@
dist: xenial
language: python language: python
matrix: matrix:
include: # These should match the tox env list include: # These should match the tox env list
@ -6,9 +7,8 @@ matrix:
python: 3.6 python: 3.6
- env: TOXENV=py37 - env: TOXENV=py37
python: 3.7 python: 3.7
dist: xenial
- env: TOXENV=pypy - env: TOXENV=pypy
python: pypy-5.7.1 python: pypy2.7-5.10.0
install: pip install coveralls tox install: pip install coveralls tox
script: tox script: tox
before_install: before_install:

View file

@ -4,7 +4,9 @@ import io
import os.path import os.path
import shutil import shutil
import tarfile import tarfile
from urllib.request import urlopen import urllib.request
from typing import cast
from typing import IO
DOWNLOAD_PATH = ( DOWNLOAD_PATH = (
'https://github.com/github/git-lfs/releases/download/' 'https://github.com/github/git-lfs/releases/download/'
@ -15,7 +17,7 @@ DEST_PATH = '/tmp/git-lfs/git-lfs'
DEST_DIR = os.path.dirname(DEST_PATH) DEST_DIR = os.path.dirname(DEST_PATH)
def main(): def main(): # type: () -> int
if ( if (
os.path.exists(DEST_PATH) and os.path.exists(DEST_PATH) and
os.path.isfile(DEST_PATH) and os.path.isfile(DEST_PATH) and
@ -27,12 +29,13 @@ def main():
shutil.rmtree(DEST_DIR, ignore_errors=True) shutil.rmtree(DEST_DIR, ignore_errors=True)
os.makedirs(DEST_DIR, exist_ok=True) os.makedirs(DEST_DIR, exist_ok=True)
contents = io.BytesIO(urlopen(DOWNLOAD_PATH).read()) contents = io.BytesIO(urllib.request.urlopen(DOWNLOAD_PATH).read())
with tarfile.open(fileobj=contents) as tar: with tarfile.open(fileobj=contents) as tar:
with tar.extractfile(PATH_IN_TAR) as src_file: with cast(IO[bytes], tar.extractfile(PATH_IN_TAR)) as src_file:
with open(DEST_PATH, 'wb') as dest_file: with open(DEST_PATH, 'wb') as dest_file:
shutil.copyfileobj(src_file, dest_file) shutil.copyfileobj(src_file, dest_file)
os.chmod(DEST_PATH, 0o755) os.chmod(DEST_PATH, 0o755)
return 0
if __name__ == '__main__': if __name__ == '__main__':

12
mypy.ini Normal file
View file

@ -0,0 +1,12 @@
[mypy]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true
[mypy-testing.*]
disallow_untyped_defs = false
[mypy-tests.*]
disallow_untyped_defs = false

View file

@ -3,7 +3,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
def main(argv=None): def main(): # type: () -> int
raise SystemExit( raise SystemExit(
'autopep8-wrapper is deprecated. Instead use autopep8 directly via ' 'autopep8-wrapper is deprecated. Instead use autopep8 directly via '
'https://github.com/pre-commit/mirrors-autopep8', 'https://github.com/pre-commit/mirrors-autopep8',

View file

@ -7,13 +7,17 @@ import argparse
import json import json
import math import math
import os import os
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Set
from pre_commit_hooks.util import added_files from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output from pre_commit_hooks.util import cmd_output
def lfs_files(): def lfs_files(): # type: () -> Set[str]
try: try:
# Introduced in git-lfs 2.2.0, first working in 2.2.1 # Introduced in git-lfs 2.2.0, first working in 2.2.1
lfs_ret = cmd_output('git', 'lfs', 'status', '--json') lfs_ret = cmd_output('git', 'lfs', 'status', '--json')
@ -24,6 +28,7 @@ def lfs_files():
def find_large_added_files(filenames, maxkb): def find_large_added_files(filenames, maxkb):
# type: (Iterable[str], int) -> int
# Find all added files that are also in the list of files pre-commit tells # Find all added files that are also in the list of files pre-commit tells
# us about # us about
filenames = (added_files() & set(filenames)) - lfs_files() filenames = (added_files() & set(filenames)) - lfs_files()
@ -38,7 +43,7 @@ def find_large_added_files(filenames, maxkb):
return retv return retv
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'filenames', nargs='*', 'filenames', nargs='*',

View file

@ -7,9 +7,11 @@ import ast
import platform import platform
import sys import sys
import traceback import traceback
from typing import Optional
from typing import Sequence
def check_ast(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -34,4 +36,4 @@ def check_ast(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
exit(check_ast()) exit(main())

View file

@ -4,6 +4,10 @@ import argparse
import ast import ast
import collections import collections
import sys import sys
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
BUILTIN_TYPES = { BUILTIN_TYPES = {
@ -22,14 +26,17 @@ BuiltinTypeCall = collections.namedtuple('BuiltinTypeCall', ['name', 'line', 'co
class BuiltinTypeVisitor(ast.NodeVisitor): class BuiltinTypeVisitor(ast.NodeVisitor):
def __init__(self, ignore=None, allow_dict_kwargs=True): def __init__(self, ignore=None, allow_dict_kwargs=True):
self.builtin_type_calls = [] # type: (Optional[Sequence[str]], bool) -> None
self.builtin_type_calls = [] # type: List[BuiltinTypeCall]
self.ignore = set(ignore) if ignore else set() self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs self.allow_dict_kwargs = allow_dict_kwargs
def _check_dict_call(self, node): def _check_dict_call(self, node): # type: (ast.Call) -> bool
return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None)) return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None))
def visit_Call(self, node): def visit_Call(self, node): # type: (ast.Call) -> None
if not isinstance(node.func, ast.Name): if not isinstance(node.func, ast.Name):
# Ignore functions that are object attributes (`foo.bar()`). # Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what # Assume that if the user calls `builtins.list()`, they know what
@ -47,6 +54,7 @@ class BuiltinTypeVisitor(ast.NodeVisitor):
def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True): def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True):
# type: (str, Optional[Sequence[str]], bool) -> List[BuiltinTypeCall]
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
tree = ast.parse(f.read(), filename=filename) tree = ast.parse(f.read(), filename=filename)
visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs) visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs)
@ -54,24 +62,22 @@ def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_k
return visitor.builtin_type_calls return visitor.builtin_type_calls
def parse_args(argv): def parse_ignore(value): # type: (str) -> Set[str]
def parse_ignore(value): return set(value.split(','))
return set(value.split(','))
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
parser.add_argument('--ignore', type=parse_ignore, default=set()) parser.add_argument('--ignore', type=parse_ignore, default=set())
allow_dict_kwargs = parser.add_mutually_exclusive_group(required=False) mutex = parser.add_mutually_exclusive_group(required=False)
allow_dict_kwargs.add_argument('--allow-dict-kwargs', action='store_true') mutex.add_argument('--allow-dict-kwargs', action='store_true')
allow_dict_kwargs.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false') mutex.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false')
allow_dict_kwargs.set_defaults(allow_dict_kwargs=True) mutex.set_defaults(allow_dict_kwargs=True)
return parser.parse_args(argv) args = parser.parse_args(argv)
def main(argv=None):
args = parse_args(argv)
rc = 0 rc = 0
for filename in args.filenames: for filename in args.filenames:
calls = check_file_for_builtin_type_constructors( calls = check_file_for_builtin_type_constructors(

View file

@ -3,9 +3,11 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import argparse import argparse
from typing import Optional
from typing import Sequence
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check') parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -3,16 +3,20 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import argparse import argparse
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Set
from pre_commit_hooks.util import added_files from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import cmd_output from pre_commit_hooks.util import cmd_output
def lower_set(iterable): def lower_set(iterable): # type: (Iterable[str]) -> Set[str]
return {x.lower() for x in iterable} return {x.lower() for x in iterable}
def find_conflicting_filenames(filenames): def find_conflicting_filenames(filenames): # type: (Sequence[str]) -> int
repo_files = set(cmd_output('git', 'ls-files').splitlines()) repo_files = set(cmd_output('git', 'ls-files').splitlines())
relevant_files = set(filenames) | added_files() relevant_files = set(filenames) | added_files()
repo_files -= relevant_files repo_files -= relevant_files
@ -41,7 +45,7 @@ def find_conflicting_filenames(filenames):
return retv return retv
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'filenames', nargs='*', 'filenames', nargs='*',

View file

@ -5,6 +5,8 @@ from __future__ import unicode_literals
import argparse import argparse
import io import io
import tokenize import tokenize
from typing import Optional
from typing import Sequence
NON_CODE_TOKENS = frozenset(( NON_CODE_TOKENS = frozenset((
@ -13,6 +15,7 @@ NON_CODE_TOKENS = frozenset((
def check_docstring_first(src, filename='<unknown>'): def check_docstring_first(src, filename='<unknown>'):
# type: (str, str) -> int
"""Returns nonzero if the source has what looks like a docstring that is """Returns nonzero if the source has what looks like a docstring that is
not at the beginning of the source. not at the beginning of the source.
@ -50,7 +53,7 @@ def check_docstring_first(src, filename='<unknown>'):
return 0 return 0
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -6,9 +6,11 @@ from __future__ import unicode_literals
import argparse import argparse
import pipes import pipes
import sys import sys
from typing import Optional
from typing import Sequence
def check_has_shebang(path): def check_has_shebang(path): # type: (str) -> int
with open(path, 'rb') as f: with open(path, 'rb') as f:
first_bytes = f.read(2) first_bytes = f.read(2)
@ -27,7 +29,7 @@ def check_has_shebang(path):
return 0 return 0
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -38,3 +40,7 @@ def main(argv=None):
retv |= check_has_shebang(filename) retv |= check_has_shebang(filename)
return retv return retv
if __name__ == '__main__':
exit(main())

View file

@ -4,9 +4,11 @@ import argparse
import io import io
import json import json
import sys import sys
from typing import Optional
from typing import Sequence
def check_json(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='JSON filenames to check.') parser.add_argument('filenames', nargs='*', help='JSON filenames to check.')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -22,4 +24,4 @@ def check_json(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(check_json()) sys.exit(main())

View file

@ -2,6 +2,9 @@ from __future__ import print_function
import argparse import argparse
import os.path import os.path
from typing import Optional
from typing import Sequence
CONFLICT_PATTERNS = [ CONFLICT_PATTERNS = [
b'<<<<<<< ', b'<<<<<<< ',
@ -12,7 +15,7 @@ CONFLICT_PATTERNS = [
WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}' WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}'
def is_in_merge(): def is_in_merge(): # type: () -> int
return ( return (
os.path.exists(os.path.join('.git', 'MERGE_MSG')) and os.path.exists(os.path.join('.git', 'MERGE_MSG')) and
( (
@ -23,7 +26,7 @@ def is_in_merge():
) )
def detect_merge_conflict(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
parser.add_argument('--assume-in-merge', action='store_true') parser.add_argument('--assume-in-merge', action='store_true')
@ -47,4 +50,4 @@ def detect_merge_conflict(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
exit(detect_merge_conflict()) exit(main())

View file

@ -4,9 +4,11 @@ from __future__ import unicode_literals
import argparse import argparse
import os.path import os.path
from typing import Optional
from typing import Sequence
def check_symlinks(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description='Checks for broken symlinks.') parser = argparse.ArgumentParser(description='Checks for broken symlinks.')
parser.add_argument('filenames', nargs='*', help='Filenames to check') parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -25,4 +27,4 @@ def check_symlinks(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
exit(check_symlinks()) exit(main())

View file

@ -5,6 +5,8 @@ from __future__ import unicode_literals
import argparse import argparse
import re import re
import sys import sys
from typing import Optional
from typing import Sequence
GITHUB_NON_PERMALINK = re.compile( GITHUB_NON_PERMALINK = re.compile(
@ -12,7 +14,7 @@ GITHUB_NON_PERMALINK = re.compile(
) )
def _check_filename(filename): def _check_filename(filename): # type: (str) -> int
retv = 0 retv = 0
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
for i, line in enumerate(f, 1): for i, line in enumerate(f, 1):
@ -24,7 +26,7 @@ def _check_filename(filename):
return retv return retv
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -5,10 +5,12 @@ from __future__ import unicode_literals
import argparse import argparse
import io import io
import sys import sys
import xml.sax import xml.sax.handler
from typing import Optional
from typing import Sequence
def check_xml(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='XML filenames to check.') parser.add_argument('filenames', nargs='*', help='XML filenames to check.')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -17,7 +19,7 @@ def check_xml(argv=None):
for filename in args.filenames: for filename in args.filenames:
try: try:
with io.open(filename, 'rb') as xml_file: with io.open(filename, 'rb') as xml_file:
xml.sax.parse(xml_file, xml.sax.ContentHandler()) xml.sax.parse(xml_file, xml.sax.handler.ContentHandler())
except xml.sax.SAXException as exc: except xml.sax.SAXException as exc:
print('{}: Failed to xml parse ({})'.format(filename, exc)) print('{}: Failed to xml parse ({})'.format(filename, exc))
retval = 1 retval = 1
@ -25,4 +27,4 @@ def check_xml(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(check_xml()) sys.exit(main())

View file

@ -3,22 +3,26 @@ from __future__ import print_function
import argparse import argparse
import collections import collections
import sys import sys
from typing import Any
from typing import Generator
from typing import Optional
from typing import Sequence
import ruamel.yaml import ruamel.yaml
yaml = ruamel.yaml.YAML(typ='safe') yaml = ruamel.yaml.YAML(typ='safe')
def _exhaust(gen): def _exhaust(gen): # type: (Generator[str, None, None]) -> None
for _ in gen: for _ in gen:
pass pass
def _parse_unsafe(*args, **kwargs): def _parse_unsafe(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.parse(*args, **kwargs)) _exhaust(yaml.parse(*args, **kwargs))
def _load_all(*args, **kwargs): def _load_all(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.load_all(*args, **kwargs)) _exhaust(yaml.load_all(*args, **kwargs))
@ -31,7 +35,7 @@ LOAD_FNS = {
} }
def check_yaml(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'-m', '--multi', '--allow-multiple-documents', action='store_true', '-m', '--multi', '--allow-multiple-documents', action='store_true',
@ -63,4 +67,4 @@ def check_yaml(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(check_yaml()) sys.exit(main())

View file

@ -5,6 +5,9 @@ import argparse
import ast import ast
import collections import collections
import traceback import traceback
from typing import List
from typing import Optional
from typing import Sequence
DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'} DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'}
@ -12,21 +15,21 @@ Debug = collections.namedtuple('Debug', ('line', 'col', 'name', 'reason'))
class DebugStatementParser(ast.NodeVisitor): class DebugStatementParser(ast.NodeVisitor):
def __init__(self): def __init__(self): # type: () -> None
self.breakpoints = [] self.breakpoints = [] # type: List[Debug]
def visit_Import(self, node): def visit_Import(self, node): # type: (ast.Import) -> None
for name in node.names: for name in node.names:
if name.name in DEBUG_STATEMENTS: if name.name in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, name.name, 'imported') st = Debug(node.lineno, node.col_offset, name.name, 'imported')
self.breakpoints.append(st) self.breakpoints.append(st)
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node): # type: (ast.ImportFrom) -> None
if node.module in DEBUG_STATEMENTS: if node.module in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, node.module, 'imported') st = Debug(node.lineno, node.col_offset, node.module, 'imported')
self.breakpoints.append(st) self.breakpoints.append(st)
def visit_Call(self, node): def visit_Call(self, node): # type: (ast.Call) -> None
"""python3.7+ breakpoint()""" """python3.7+ breakpoint()"""
if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint':
st = Debug(node.lineno, node.col_offset, node.func.id, 'called') st = Debug(node.lineno, node.col_offset, node.func.id, 'called')
@ -34,7 +37,7 @@ class DebugStatementParser(ast.NodeVisitor):
self.generic_visit(node) self.generic_visit(node)
def check_file(filename): def check_file(filename): # type: (str) -> int
try: try:
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
ast_obj = ast.parse(f.read(), filename=filename) ast_obj = ast.parse(f.read(), filename=filename)
@ -58,7 +61,7 @@ def check_file(filename):
return int(bool(visitor.breakpoints)) return int(bool(visitor.breakpoints))
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to run') parser.add_argument('filenames', nargs='*', help='Filenames to run')
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -3,11 +3,16 @@ from __future__ import unicode_literals
import argparse import argparse
import os import os
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from six.moves import configparser from six.moves import configparser
def get_aws_credential_files_from_env(): def get_aws_credential_files_from_env(): # type: () -> Set[str]
"""Extract credential file paths from environment variables.""" """Extract credential file paths from environment variables."""
files = set() files = set()
for env_var in ( for env_var in (
@ -19,7 +24,7 @@ def get_aws_credential_files_from_env():
return files return files
def get_aws_secrets_from_env(): def get_aws_secrets_from_env(): # type: () -> Set[str]
"""Extract AWS secrets from environment variables.""" """Extract AWS secrets from environment variables."""
keys = set() keys = set()
for env_var in ( for env_var in (
@ -30,7 +35,7 @@ def get_aws_secrets_from_env():
return keys return keys
def get_aws_secrets_from_file(credentials_file): def get_aws_secrets_from_file(credentials_file): # type: (str) -> Set[str]
"""Extract AWS secrets from configuration files. """Extract AWS secrets from configuration files.
Read an ini-style configuration file and return a set with all found AWS Read an ini-style configuration file and return a set with all found AWS
@ -62,6 +67,7 @@ def get_aws_secrets_from_file(credentials_file):
def check_file_for_aws_keys(filenames, keys): def check_file_for_aws_keys(filenames, keys):
# type: (Sequence[str], Set[str]) -> List[Dict[str, str]]
"""Check if files contain AWS secrets. """Check if files contain AWS secrets.
Return a list of all files containing AWS secrets and keys found, with all Return a list of all files containing AWS secrets and keys found, with all
@ -82,7 +88,7 @@ def check_file_for_aws_keys(filenames, keys):
return bad_files return bad_files
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Filenames to run') parser.add_argument('filenames', nargs='+', help='Filenames to run')
parser.add_argument( parser.add_argument(
@ -111,7 +117,7 @@ def main(argv=None):
# of files to to gather AWS secrets from. # of files to to gather AWS secrets from.
credential_files |= get_aws_credential_files_from_env() credential_files |= get_aws_credential_files_from_env()
keys = set() keys = set() # type: Set[str]
for credential_file in credential_files: for credential_file in credential_files:
keys |= get_aws_secrets_from_file(credential_file) keys |= get_aws_secrets_from_file(credential_file)

View file

@ -2,6 +2,8 @@ from __future__ import print_function
import argparse import argparse
import sys import sys
from typing import Optional
from typing import Sequence
BLACKLIST = [ BLACKLIST = [
b'BEGIN RSA PRIVATE KEY', b'BEGIN RSA PRIVATE KEY',
@ -15,7 +17,7 @@ BLACKLIST = [
] ]
def detect_private_key(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check') parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -37,4 +39,4 @@ def detect_private_key(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(detect_private_key()) sys.exit(main())

View file

@ -4,9 +4,12 @@ from __future__ import unicode_literals
import argparse import argparse
import os import os
import sys import sys
from typing import IO
from typing import Optional
from typing import Sequence
def fix_file(file_obj): def fix_file(file_obj): # type: (IO[bytes]) -> int
# Test for newline at end of file # Test for newline at end of file
# Empty files will throw IOError here # Empty files will throw IOError here
try: try:
@ -49,7 +52,7 @@ def fix_file(file_obj):
return 0 return 0
def end_of_file_fixer(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -68,4 +71,4 @@ def end_of_file_fixer(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(end_of_file_fixer()) sys.exit(main())

View file

@ -12,12 +12,15 @@ conflicts and keep the file nicely ordered.
from __future__ import print_function from __future__ import print_function
import argparse import argparse
from typing import IO
from typing import Optional
from typing import Sequence
PASS = 0 PASS = 0
FAIL = 1 FAIL = 1
def sort_file_contents(f): def sort_file_contents(f): # type: (IO[bytes]) -> int
before = list(f) before = list(f)
after = sorted([line.strip(b'\n\r') for line in before if line.strip()]) after = sorted([line.strip(b'\n\r') for line in before if line.strip()])
@ -33,7 +36,7 @@ def sort_file_contents(f):
return FAIL return FAIL
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Files to sort') parser.add_argument('filenames', nargs='+', help='Files to sort')
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -4,11 +4,15 @@ from __future__ import unicode_literals
import argparse import argparse
import collections import collections
from typing import IO
from typing import Optional
from typing import Sequence
from typing import Union
DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n' DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
def has_coding(line): def has_coding(line): # type: (bytes) -> bool
if not line.strip(): if not line.strip():
return False return False
return ( return (
@ -33,15 +37,16 @@ class ExpectedContents(collections.namedtuple(
__slots__ = () __slots__ = ()
@property @property
def has_any_pragma(self): def has_any_pragma(self): # type: () -> bool
return self.pragma_status is not False return self.pragma_status is not False
def is_expected_pragma(self, remove): def is_expected_pragma(self, remove): # type: (bool) -> bool
expected_pragma_status = not remove expected_pragma_status = not remove
return self.pragma_status is expected_pragma_status return self.pragma_status is expected_pragma_status
def _get_expected_contents(first_line, second_line, rest, expected_pragma): def _get_expected_contents(first_line, second_line, rest, expected_pragma):
# type: (bytes, bytes, bytes, bytes) -> ExpectedContents
if first_line.startswith(b'#!'): if first_line.startswith(b'#!'):
shebang = first_line shebang = first_line
potential_coding = second_line potential_coding = second_line
@ -51,7 +56,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
rest = second_line + rest rest = second_line + rest
if potential_coding == expected_pragma: if potential_coding == expected_pragma:
pragma_status = True pragma_status = True # type: Optional[bool]
elif has_coding(potential_coding): elif has_coding(potential_coding):
pragma_status = None pragma_status = None
else: else:
@ -64,6 +69,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
# type: (IO[bytes], bool, bytes) -> int
expected = _get_expected_contents( expected = _get_expected_contents(
f.readline(), f.readline(), f.read(), expected_pragma, f.readline(), f.readline(), f.read(), expected_pragma,
) )
@ -93,17 +99,17 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
return 1 return 1
def _normalize_pragma(pragma): def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes
if not isinstance(pragma, bytes): if not isinstance(pragma, bytes):
pragma = pragma.encode('UTF-8') pragma = pragma.encode('UTF-8')
return pragma.rstrip() + b'\n' return pragma.rstrip() + b'\n'
def _to_disp(pragma): def _to_disp(pragma): # type: (bytes) -> str
return pragma.decode().rstrip() return pragma.decode().rstrip()
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser('Fixes the encoding pragma of python files') parser = argparse.ArgumentParser('Fixes the encoding pragma of python files')
parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument('filenames', nargs='*', help='Filenames to fix')
parser.add_argument( parser.add_argument(

View file

@ -2,10 +2,13 @@ from __future__ import absolute_import
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from typing import Optional
from typing import Sequence
from pre_commit_hooks.util import cmd_output from pre_commit_hooks.util import cmd_output
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
# `argv` is ignored, pre-commit will send us a list of files that we # `argv` is ignored, pre-commit will send us a list of files that we
# don't care about # don't care about
added_diff = cmd_output( added_diff = cmd_output(

View file

@ -4,6 +4,9 @@ from __future__ import unicode_literals
import argparse import argparse
import collections import collections
from typing import Dict
from typing import Optional
from typing import Sequence
CRLF = b'\r\n' CRLF = b'\r\n'
@ -14,7 +17,7 @@ ALL_ENDINGS = (CR, CRLF, LF)
FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF} FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF}
def _fix(filename, contents, ending): def _fix(filename, contents, ending): # type: (str, bytes, bytes) -> None
new_contents = b''.join( new_contents = b''.join(
line.rstrip(b'\r\n') + ending for line in contents.splitlines(True) line.rstrip(b'\r\n') + ending for line in contents.splitlines(True)
) )
@ -22,11 +25,11 @@ def _fix(filename, contents, ending):
f.write(new_contents) f.write(new_contents)
def fix_filename(filename, fix): def fix_filename(filename, fix): # type: (str, str) -> int
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
contents = f.read() contents = f.read()
counts = collections.defaultdict(int) counts = collections.defaultdict(int) # type: Dict[bytes, int]
for line in contents.splitlines(True): for line in contents.splitlines(True):
for ending in ALL_ENDINGS: for ending in ALL_ENDINGS:
@ -63,7 +66,7 @@ def fix_filename(filename, fix):
return other_endings return other_endings
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'-f', '--fix', '-f', '--fix',

View file

@ -1,12 +1,15 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
from typing import Optional
from typing import Sequence
from typing import Set
from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output from pre_commit_hooks.util import cmd_output
def is_on_branch(protected): def is_on_branch(protected): # type: (Set[str]) -> bool
try: try:
branch = cmd_output('git', 'symbolic-ref', 'HEAD') branch = cmd_output('git', 'symbolic-ref', 'HEAD')
except CalledProcessError: except CalledProcessError:
@ -15,7 +18,7 @@ def is_on_branch(protected):
return '/'.join(chunks[2:]) in protected return '/'.join(chunks[2:]) in protected
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'-b', '--branch', action='append', '-b', '--branch', action='append',

View file

@ -5,12 +5,20 @@ import io
import json import json
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from six import text_type from six import text_type
def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=[]): def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=()):
# type: (str, str, bool, bool, Sequence[str]) -> str
def pairs_first(pairs): def pairs_first(pairs):
# type: (Sequence[Tuple[str, str]]) -> Mapping[str, str]
before = [pair for pair in pairs if pair[0] in top_keys] before = [pair for pair in pairs if pair[0] in top_keys]
before = sorted(before, key=lambda x: top_keys.index(x[0])) before = sorted(before, key=lambda x: top_keys.index(x[0]))
after = [pair for pair in pairs if pair[0] not in top_keys] after = [pair for pair in pairs if pair[0] not in top_keys]
@ -27,13 +35,13 @@ def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_
return text_type(json_pretty) + '\n' return text_type(json_pretty) + '\n'
def _autofix(filename, new_contents): def _autofix(filename, new_contents): # type: (str, str) -> None
print('Fixing file {}'.format(filename)) print('Fixing file {}'.format(filename))
with io.open(filename, 'w', encoding='UTF-8') as f: with io.open(filename, 'w', encoding='UTF-8') as f:
f.write(new_contents) f.write(new_contents)
def parse_num_to_int(s): def parse_num_to_int(s): # type: (str) -> Union[int, str]
"""Convert string numbers to int, leaving strings as is.""" """Convert string numbers to int, leaving strings as is."""
try: try:
return int(s) return int(s)
@ -41,11 +49,11 @@ def parse_num_to_int(s):
return s return s
def parse_topkeys(s): def parse_topkeys(s): # type: (str) -> List[str]
return s.split(',') return s.split(',')
def pretty_format_json(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'--autofix', '--autofix',
@ -117,4 +125,4 @@ def pretty_format_json(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(pretty_format_json()) sys.exit(main())

View file

@ -1,6 +1,10 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
from typing import IO
from typing import List
from typing import Optional
from typing import Sequence
PASS = 0 PASS = 0
@ -9,21 +13,23 @@ FAIL = 1
class Requirement(object): class Requirement(object):
def __init__(self): def __init__(self): # type: () -> None
super(Requirement, self).__init__() super(Requirement, self).__init__()
self.value = None self.value = None # type: Optional[bytes]
self.comments = [] self.comments = [] # type: List[bytes]
@property @property
def name(self): def name(self): # type: () -> bytes
assert self.value is not None, self.value
if self.value.startswith(b'-e '): if self.value.startswith(b'-e '):
return self.value.lower().partition(b'=')[-1] return self.value.lower().partition(b'=')[-1]
return self.value.lower().partition(b'==')[0] return self.value.lower().partition(b'==')[0]
def __lt__(self, requirement): def __lt__(self, requirement): # type: (Requirement) -> int
# \n means top of file comment, so always return True, # \n means top of file comment, so always return True,
# otherwise just do a string comparison with value. # otherwise just do a string comparison with value.
assert self.value is not None, self.value
if self.value == b'\n': if self.value == b'\n':
return True return True
elif requirement.value == b'\n': elif requirement.value == b'\n':
@ -32,10 +38,10 @@ class Requirement(object):
return self.name < requirement.name return self.name < requirement.name
def fix_requirements(f): def fix_requirements(f): # type: (IO[bytes]) -> int
requirements = [] requirements = [] # type: List[Requirement]
before = tuple(f) before = tuple(f)
after = [] after = [] # type: List[bytes]
before_string = b''.join(before) before_string = b''.join(before)
@ -46,6 +52,7 @@ def fix_requirements(f):
for line in before: for line in before:
# If the most recent requirement object has a value, then it's # If the most recent requirement object has a value, then it's
# time to start building the next requirement object. # time to start building the next requirement object.
if not len(requirements) or requirements[-1].value is not None: if not len(requirements) or requirements[-1].value is not None:
requirements.append(Requirement()) requirements.append(Requirement())
@ -78,6 +85,7 @@ def fix_requirements(f):
for requirement in sorted(requirements): for requirement in sorted(requirements):
after.extend(requirement.comments) after.extend(requirement.comments)
assert requirement.value, requirement.value
after.append(requirement.value) after.append(requirement.value)
after.extend(rest) after.extend(rest)
@ -92,7 +100,7 @@ def fix_requirements(f):
return FAIL return FAIL
def fix_requirements_txt(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -109,3 +117,7 @@ def fix_requirements_txt(argv=None):
retv |= ret_for_file retv |= ret_for_file
return retv return retv
if __name__ == '__main__':
exit(main())

View file

@ -21,12 +21,15 @@ complicated YAML files.
from __future__ import print_function from __future__ import print_function
import argparse import argparse
from typing import List
from typing import Optional
from typing import Sequence
QUOTES = ["'", '"'] QUOTES = ["'", '"']
def sort(lines): def sort(lines): # type: (List[str]) -> List[str]
"""Sort a YAML file in alphabetical order, keeping blocks together. """Sort a YAML file in alphabetical order, keeping blocks together.
:param lines: array of strings (without newlines) :param lines: array of strings (without newlines)
@ -44,7 +47,7 @@ def sort(lines):
return new_lines return new_lines
def parse_block(lines, header=False): def parse_block(lines, header=False): # type: (List[str], bool) -> List[str]
"""Parse and return a single block, popping off the start of `lines`. """Parse and return a single block, popping off the start of `lines`.
If parsing a header block, we stop after we reach a line that is not a If parsing a header block, we stop after we reach a line that is not a
@ -60,7 +63,7 @@ def parse_block(lines, header=False):
return block_lines return block_lines
def parse_blocks(lines): def parse_blocks(lines): # type: (List[str]) -> List[List[str]]
"""Parse and return all possible blocks, popping off the start of `lines`. """Parse and return all possible blocks, popping off the start of `lines`.
:param lines: list of lines :param lines: list of lines
@ -77,7 +80,7 @@ def parse_blocks(lines):
return blocks return blocks
def first_key(lines): def first_key(lines): # type: (List[str]) -> str
"""Returns a string representing the sort key of a block. """Returns a string representing the sort key of a block.
The sort key is the first YAML key we encounter, ignoring comments, and The sort key is the first YAML key we encounter, ignoring comments, and
@ -95,9 +98,11 @@ def first_key(lines):
if any(line.startswith(quote) for quote in QUOTES): if any(line.startswith(quote) for quote in QUOTES):
return line[1:] return line[1:]
return line return line
else:
return '' # not actually reached in reality
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -4,34 +4,39 @@ from __future__ import unicode_literals
import argparse import argparse
import io import io
import re
import tokenize import tokenize
from typing import List
from typing import Optional
from typing import Sequence
START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s) def handle_match(token_text): # type: (str) -> str
def handle_match(token_text):
if '"""' in token_text or "'''" in token_text: if '"""' in token_text or "'''" in token_text:
return token_text return token_text
for double_quote_start in double_quote_starts: match = START_QUOTE_RE.match(token_text)
if token_text.startswith(double_quote_start): if match is not None:
meat = token_text[len(double_quote_start):-1] meat = token_text[match.end():-1]
if '"' in meat or "'" in meat: if '"' in meat or "'" in meat:
break return token_text
return double_quote_start.replace('"', "'") + meat + "'" else:
return token_text return match.group().replace('"', "'") + meat + "'"
else:
return token_text
def get_line_offsets_by_line_no(src): def get_line_offsets_by_line_no(src): # type: (str) -> List[int]
# Padded so we can index with line number # Padded so we can index with line number
offsets = [None, 0] offsets = [-1, 0]
for line in src.splitlines(): for line in src.splitlines():
offsets.append(offsets[-1] + len(line) + 1) offsets.append(offsets[-1] + len(line) + 1)
return offsets return offsets
def fix_strings(filename): def fix_strings(filename): # type: (str) -> int
with io.open(filename, encoding='UTF-8') as f: with io.open(filename, encoding='UTF-8') as f:
contents = f.read() contents = f.read()
line_offsets = get_line_offsets_by_line_no(contents) line_offsets = get_line_offsets_by_line_no(contents)
@ -60,7 +65,7 @@ def fix_strings(filename):
return 0 return 0
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -74,3 +79,7 @@ def main(argv=None):
retv |= return_value retv |= return_value
return retv return retv
if __name__ == '__main__':
exit(main())

View file

@ -1,12 +1,14 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os.path
import re import re
import sys import sys
from os.path import basename from typing import Optional
from typing import Sequence
def validate_files(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*') parser.add_argument('filenames', nargs='*')
parser.add_argument( parser.add_argument(
@ -18,7 +20,7 @@ def validate_files(argv=None):
retcode = 0 retcode = 0
test_name_pattern = 'test.*.py' if args.django else '.*_test.py' test_name_pattern = 'test.*.py' if args.django else '.*_test.py'
for filename in args.filenames: for filename in args.filenames:
base = basename(filename) base = os.path.basename(filename)
if ( if (
not re.match(test_name_pattern, base) and not re.match(test_name_pattern, base) and
not base == '__init__.py' and not base == '__init__.py' and
@ -35,4 +37,4 @@ def validate_files(argv=None):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(validate_files()) sys.exit(main())

View file

@ -3,9 +3,11 @@ from __future__ import print_function
import argparse import argparse
import os import os
import sys import sys
from typing import Optional
from typing import Sequence
def _fix_file(filename, is_markdown): def _fix_file(filename, is_markdown): # type: (str, bool) -> bool
with open(filename, mode='rb') as file_processed: with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines() lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown) for line in lines] newlines = [_process_line(line, is_markdown) for line in lines]
@ -18,7 +20,7 @@ def _fix_file(filename, is_markdown):
return False return False
def _process_line(line, is_markdown): def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes
if line[-2:] == b'\r\n': if line[-2:] == b'\r\n':
eol = b'\r\n' eol = b'\r\n'
elif line[-1:] == b'\n': elif line[-1:] == b'\n':
@ -31,7 +33,7 @@ def _process_line(line, is_markdown):
return line.rstrip() + eol return line.rstrip() + eol
def main(argv=None): def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'--no-markdown-linebreak-ext', '--no-markdown-linebreak-ext',

View file

@ -3,23 +3,25 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import subprocess import subprocess
from typing import Any
from typing import Set
class CalledProcessError(RuntimeError): class CalledProcessError(RuntimeError):
pass pass
def added_files(): def added_files(): # type: () -> Set[str]
return set(cmd_output( return set(cmd_output(
'git', 'diff', '--staged', '--name-only', '--diff-filter=A', 'git', 'diff', '--staged', '--name-only', '--diff-filter=A',
).splitlines()) ).splitlines())
def cmd_output(*cmd, **kwargs): def cmd_output(*cmd, **kwargs): # type: (*str, **Any) -> str
retcode = kwargs.pop('retcode', 0) retcode = kwargs.pop('retcode', 0)
popen_kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE} kwargs.setdefault('stdout', subprocess.PIPE)
popen_kwargs.update(kwargs) kwargs.setdefault('stderr', subprocess.PIPE)
proc = subprocess.Popen(cmd, **popen_kwargs) proc = subprocess.Popen(cmd, **kwargs)
stdout, stderr = proc.communicate() stdout, stderr = proc.communicate()
stdout = stdout.decode('UTF-8') stdout = stdout.decode('UTF-8')
if stderr is not None: if stderr is not None:

View file

@ -28,35 +28,36 @@ setup(
'ruamel.yaml>=0.15', 'ruamel.yaml>=0.15',
'six', 'six',
], ],
extras_require={':python_version<"3.5"': ['typing']},
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'autopep8-wrapper = pre_commit_hooks.autopep8_wrapper:main', 'autopep8-wrapper = pre_commit_hooks.autopep8_wrapper:main',
'check-added-large-files = pre_commit_hooks.check_added_large_files:main', 'check-added-large-files = pre_commit_hooks.check_added_large_files:main',
'check-ast = pre_commit_hooks.check_ast:check_ast', 'check-ast = pre_commit_hooks.check_ast:main',
'check-builtin-literals = pre_commit_hooks.check_builtin_literals:main', 'check-builtin-literals = pre_commit_hooks.check_builtin_literals:main',
'check-byte-order-marker = pre_commit_hooks.check_byte_order_marker:main', 'check-byte-order-marker = pre_commit_hooks.check_byte_order_marker:main',
'check-case-conflict = pre_commit_hooks.check_case_conflict:main', 'check-case-conflict = pre_commit_hooks.check_case_conflict:main',
'check-docstring-first = pre_commit_hooks.check_docstring_first:main', 'check-docstring-first = pre_commit_hooks.check_docstring_first:main',
'check-executables-have-shebangs = pre_commit_hooks.check_executables_have_shebangs:main', 'check-executables-have-shebangs = pre_commit_hooks.check_executables_have_shebangs:main',
'check-json = pre_commit_hooks.check_json:check_json', 'check-json = pre_commit_hooks.check_json:main',
'check-merge-conflict = pre_commit_hooks.check_merge_conflict:detect_merge_conflict', 'check-merge-conflict = pre_commit_hooks.check_merge_conflict:main',
'check-symlinks = pre_commit_hooks.check_symlinks:check_symlinks', 'check-symlinks = pre_commit_hooks.check_symlinks:main',
'check-vcs-permalinks = pre_commit_hooks.check_vcs_permalinks:main', 'check-vcs-permalinks = pre_commit_hooks.check_vcs_permalinks:main',
'check-xml = pre_commit_hooks.check_xml:check_xml', 'check-xml = pre_commit_hooks.check_xml:main',
'check-yaml = pre_commit_hooks.check_yaml:check_yaml', 'check-yaml = pre_commit_hooks.check_yaml:main',
'debug-statement-hook = pre_commit_hooks.debug_statement_hook:main', 'debug-statement-hook = pre_commit_hooks.debug_statement_hook:main',
'detect-aws-credentials = pre_commit_hooks.detect_aws_credentials:main', 'detect-aws-credentials = pre_commit_hooks.detect_aws_credentials:main',
'detect-private-key = pre_commit_hooks.detect_private_key:detect_private_key', 'detect-private-key = pre_commit_hooks.detect_private_key:main',
'double-quote-string-fixer = pre_commit_hooks.string_fixer:main', 'double-quote-string-fixer = pre_commit_hooks.string_fixer:main',
'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:end_of_file_fixer', 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:main',
'file-contents-sorter = pre_commit_hooks.file_contents_sorter:main', 'file-contents-sorter = pre_commit_hooks.file_contents_sorter:main',
'fix-encoding-pragma = pre_commit_hooks.fix_encoding_pragma:main', 'fix-encoding-pragma = pre_commit_hooks.fix_encoding_pragma:main',
'forbid-new-submodules = pre_commit_hooks.forbid_new_submodules:main', 'forbid-new-submodules = pre_commit_hooks.forbid_new_submodules:main',
'mixed-line-ending = pre_commit_hooks.mixed_line_ending:main', 'mixed-line-ending = pre_commit_hooks.mixed_line_ending:main',
'name-tests-test = pre_commit_hooks.tests_should_end_in_test:validate_files', 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:main',
'no-commit-to-branch = pre_commit_hooks.no_commit_to_branch:main', 'no-commit-to-branch = pre_commit_hooks.no_commit_to_branch:main',
'pretty-format-json = pre_commit_hooks.pretty_format_json:pretty_format_json', 'pretty-format-json = pre_commit_hooks.pretty_format_json:main',
'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:fix_requirements_txt', 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:main',
'sort-simple-yaml = pre_commit_hooks.sort_simple_yaml:main', 'sort-simple-yaml = pre_commit_hooks.sort_simple_yaml:main',
'trailing-whitespace-fixer = pre_commit_hooks.trailing_whitespace_fixer:main', 'trailing-whitespace-fixer = pre_commit_hooks.trailing_whitespace_fixer:main',
], ],

0
testing/resources/bad_json_latin1.nonjson Executable file → Normal file
View file

View file

@ -1,17 +0,0 @@
from six.moves import builtins
c1 = complex()
d1 = dict()
f1 = float()
i1 = int()
l1 = list()
s1 = str()
t1 = tuple()
c2 = builtins.complex()
d2 = builtins.dict()
f2 = builtins.float()
i2 = builtins.int()
l2 = builtins.list()
s2 = builtins.str()
t2 = builtins.tuple()

View file

@ -1,7 +0,0 @@
c1 = 0j
d1 = {}
f1 = 0.0
i1 = 0
l1 = []
s1 = ''
t1 = ()

View file

@ -1,15 +1,15 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
from pre_commit_hooks.check_ast import check_ast from pre_commit_hooks.check_ast import main
from testing.util import get_resource_path from testing.util import get_resource_path
def test_failing_file(): def test_failing_file():
ret = check_ast([get_resource_path('cannot_parse_ast.notpy')]) ret = main([get_resource_path('cannot_parse_ast.notpy')])
assert ret == 1 assert ret == 1
def test_passing_file(): def test_passing_file():
ret = check_ast([__file__]) ret = main([__file__])
assert ret == 0 assert ret == 0

View file

@ -5,7 +5,35 @@ import pytest
from pre_commit_hooks.check_builtin_literals import BuiltinTypeCall from pre_commit_hooks.check_builtin_literals import BuiltinTypeCall
from pre_commit_hooks.check_builtin_literals import BuiltinTypeVisitor from pre_commit_hooks.check_builtin_literals import BuiltinTypeVisitor
from pre_commit_hooks.check_builtin_literals import main from pre_commit_hooks.check_builtin_literals import main
from testing.util import get_resource_path
BUILTIN_CONSTRUCTORS = '''\
from six.moves import builtins
c1 = complex()
d1 = dict()
f1 = float()
i1 = int()
l1 = list()
s1 = str()
t1 = tuple()
c2 = builtins.complex()
d2 = builtins.dict()
f2 = builtins.float()
i2 = builtins.int()
l2 = builtins.list()
s2 = builtins.str()
t2 = builtins.tuple()
'''
BUILTIN_LITERALS = '''\
c1 = 0j
d1 = {}
f1 = 0.0
i1 = 0
l1 = []
s1 = ''
t1 = ()
'''
@pytest.fixture @pytest.fixture
@ -94,24 +122,26 @@ def test_dict_no_allow_kwargs_exprs(expression, calls):
def test_ignore_constructors(): def test_ignore_constructors():
visitor = BuiltinTypeVisitor(ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple')) visitor = BuiltinTypeVisitor(ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'))
with open(get_resource_path('builtin_constructors.py'), 'rb') as f: visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS))
visitor.visit(ast.parse(f.read(), 'builtin_constructors.py'))
assert visitor.builtin_type_calls == [] assert visitor.builtin_type_calls == []
def test_failing_file(): def test_failing_file(tmpdir):
rc = main([get_resource_path('builtin_constructors.py')]) f = tmpdir.join('f.py')
f.write(BUILTIN_CONSTRUCTORS)
rc = main([f.strpath])
assert rc == 1 assert rc == 1
def test_passing_file(): def test_passing_file(tmpdir):
rc = main([get_resource_path('builtin_literals.py')]) f = tmpdir.join('f.py')
f.write(BUILTIN_LITERALS)
rc = main([f.strpath])
assert rc == 0 assert rc == 0
def test_failing_file_ignore_all(): def test_failing_file_ignore_all(tmpdir):
rc = main([ f = tmpdir.join('f.py')
'--ignore=complex,dict,float,int,list,str,tuple', f.write(BUILTIN_CONSTRUCTORS)
get_resource_path('builtin_constructors.py'), rc = main(['--ignore=complex,dict,float,int,list,str,tuple', f.strpath])
])
assert rc == 0 assert rc == 0

View file

@ -1,6 +1,6 @@
import pytest import pytest
from pre_commit_hooks.check_json import check_json from pre_commit_hooks.check_json import main
from testing.util import get_resource_path from testing.util import get_resource_path
@ -11,8 +11,8 @@ from testing.util import get_resource_path
('ok_json.json', 0), ('ok_json.json', 0),
), ),
) )
def test_check_json(capsys, filename, expected_retval): def test_main(capsys, filename, expected_retval):
ret = check_json([get_resource_path(filename)]) ret = main([get_resource_path(filename)])
assert ret == expected_retval assert ret == expected_retval
if expected_retval == 1: if expected_retval == 1:
stdout, _ = capsys.readouterr() stdout, _ = capsys.readouterr()

View file

@ -6,7 +6,7 @@ import shutil
import pytest import pytest
from pre_commit_hooks.check_merge_conflict import detect_merge_conflict from pre_commit_hooks.check_merge_conflict import main
from pre_commit_hooks.util import cmd_output from pre_commit_hooks.util import cmd_output
from testing.util import get_resource_path from testing.util import get_resource_path
@ -102,7 +102,7 @@ def repository_pending_merge(tmpdir):
@pytest.mark.usefixtures('f1_is_a_conflict_file') @pytest.mark.usefixtures('f1_is_a_conflict_file')
def test_merge_conflicts_git(): def test_merge_conflicts_git():
assert detect_merge_conflict(['f1']) == 1 assert main(['f1']) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -110,7 +110,7 @@ def test_merge_conflicts_git():
) )
def test_merge_conflicts_failing(contents, repository_pending_merge): def test_merge_conflicts_failing(contents, repository_pending_merge):
repository_pending_merge.join('f2').write_binary(contents) repository_pending_merge.join('f2').write_binary(contents)
assert detect_merge_conflict(['f2']) == 1 assert main(['f2']) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -118,22 +118,22 @@ def test_merge_conflicts_failing(contents, repository_pending_merge):
) )
def test_merge_conflicts_ok(contents, f1_is_a_conflict_file): def test_merge_conflicts_ok(contents, f1_is_a_conflict_file):
f1_is_a_conflict_file.join('f1').write_binary(contents) f1_is_a_conflict_file.join('f1').write_binary(contents)
assert detect_merge_conflict(['f1']) == 0 assert main(['f1']) == 0
@pytest.mark.usefixtures('f1_is_a_conflict_file') @pytest.mark.usefixtures('f1_is_a_conflict_file')
def test_ignores_binary_files(): def test_ignores_binary_files():
shutil.copy(get_resource_path('img1.jpg'), 'f1') shutil.copy(get_resource_path('img1.jpg'), 'f1')
assert detect_merge_conflict(['f1']) == 0 assert main(['f1']) == 0
def test_does_not_care_when_not_in_a_merge(tmpdir): def test_does_not_care_when_not_in_a_merge(tmpdir):
f = tmpdir.join('README.md') f = tmpdir.join('README.md')
f.write_binary(b'problem\n=======\n') f.write_binary(b'problem\n=======\n')
assert detect_merge_conflict([str(f.realpath())]) == 0 assert main([str(f.realpath())]) == 0
def test_care_when_assumed_merge(tmpdir): def test_care_when_assumed_merge(tmpdir):
f = tmpdir.join('README.md') f = tmpdir.join('README.md')
f.write_binary(b'problem\n=======\n') f.write_binary(b'problem\n=======\n')
assert detect_merge_conflict([str(f.realpath()), '--assume-in-merge']) == 1 assert main([str(f.realpath()), '--assume-in-merge']) == 1

View file

@ -2,7 +2,7 @@ import os
import pytest import pytest
from pre_commit_hooks.check_symlinks import check_symlinks from pre_commit_hooks.check_symlinks import main
xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support') xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support')
@ -12,12 +12,12 @@ xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support')
@pytest.mark.parametrize( @pytest.mark.parametrize(
('dest', 'expected'), (('exists', 0), ('does-not-exist', 1)), ('dest', 'expected'), (('exists', 0), ('does-not-exist', 1)),
) )
def test_check_symlinks(tmpdir, dest, expected): # pragma: no cover (symlinks) def test_main(tmpdir, dest, expected): # pragma: no cover (symlinks)
tmpdir.join('exists').ensure() tmpdir.join('exists').ensure()
symlink = tmpdir.join('symlink') symlink = tmpdir.join('symlink')
symlink.mksymlinkto(tmpdir.join(dest)) symlink.mksymlinkto(tmpdir.join(dest))
assert check_symlinks((symlink.strpath,)) == expected assert main((symlink.strpath,)) == expected
def test_check_symlinks_normal_file(tmpdir): def test_main_normal_file(tmpdir):
assert check_symlinks((tmpdir.join('f').ensure().strpath,)) == 0 assert main((tmpdir.join('f').ensure().strpath,)) == 0

View file

@ -1,6 +1,6 @@
import pytest import pytest
from pre_commit_hooks.check_xml import check_xml from pre_commit_hooks.check_xml import main
from testing.util import get_resource_path from testing.util import get_resource_path
@ -10,6 +10,6 @@ from testing.util import get_resource_path
('ok_xml.xml', 0), ('ok_xml.xml', 0),
), ),
) )
def test_check_xml(filename, expected_retval): def test_main(filename, expected_retval):
ret = check_xml([get_resource_path(filename)]) ret = main([get_resource_path(filename)])
assert ret == expected_retval assert ret == expected_retval

View file

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import pytest import pytest
from pre_commit_hooks.check_yaml import check_yaml from pre_commit_hooks.check_yaml import main
from testing.util import get_resource_path from testing.util import get_resource_path
@ -13,29 +13,29 @@ from testing.util import get_resource_path
('ok_yaml.yaml', 0), ('ok_yaml.yaml', 0),
), ),
) )
def test_check_yaml(filename, expected_retval): def test_main(filename, expected_retval):
ret = check_yaml([get_resource_path(filename)]) ret = main([get_resource_path(filename)])
assert ret == expected_retval assert ret == expected_retval
def test_check_yaml_allow_multiple_documents(tmpdir): def test_main_allow_multiple_documents(tmpdir):
f = tmpdir.join('test.yaml') f = tmpdir.join('test.yaml')
f.write('---\nfoo\n---\nbar\n') f.write('---\nfoo\n---\nbar\n')
# should fail without the setting # should fail without the setting
assert check_yaml((f.strpath,)) assert main((f.strpath,))
# should pass when we allow multiple documents # should pass when we allow multiple documents
assert not check_yaml(('--allow-multiple-documents', f.strpath)) assert not main(('--allow-multiple-documents', f.strpath))
def test_fails_even_with_allow_multiple_documents(tmpdir): def test_fails_even_with_allow_multiple_documents(tmpdir):
f = tmpdir.join('test.yaml') f = tmpdir.join('test.yaml')
f.write('[') f.write('[')
assert check_yaml(('--allow-multiple-documents', f.strpath)) assert main(('--allow-multiple-documents', f.strpath))
def test_check_yaml_unsafe(tmpdir): def test_main_unsafe(tmpdir):
f = tmpdir.join('test.yaml') f = tmpdir.join('test.yaml')
f.write( f.write(
'some_foo: !vault |\n' 'some_foo: !vault |\n'
@ -43,12 +43,12 @@ def test_check_yaml_unsafe(tmpdir):
' deadbeefdeadbeefdeadbeef\n', ' deadbeefdeadbeefdeadbeef\n',
) )
# should fail "safe" check # should fail "safe" check
assert check_yaml((f.strpath,)) assert main((f.strpath,))
# should pass when we allow unsafe documents # should pass when we allow unsafe documents
assert not check_yaml(('--unsafe', f.strpath)) assert not main(('--unsafe', f.strpath))
def test_check_yaml_unsafe_still_fails_on_syntax_errors(tmpdir): def test_main_unsafe_still_fails_on_syntax_errors(tmpdir):
f = tmpdir.join('test.yaml') f = tmpdir.join('test.yaml')
f.write('[') f.write('[')
assert check_yaml(('--unsafe', f.strpath)) assert main(('--unsafe', f.strpath))

View file

@ -1,6 +1,6 @@
import pytest import pytest
from pre_commit_hooks.detect_private_key import detect_private_key from pre_commit_hooks.detect_private_key import main
# Input, expected return value # Input, expected return value
TESTS = ( TESTS = (
@ -18,7 +18,7 @@ TESTS = (
@pytest.mark.parametrize(('input_s', 'expected_retval'), TESTS) @pytest.mark.parametrize(('input_s', 'expected_retval'), TESTS)
def test_detect_private_key(input_s, expected_retval, tmpdir): def test_main(input_s, expected_retval, tmpdir):
path = tmpdir.join('file.txt') path = tmpdir.join('file.txt')
path.write_binary(input_s) path.write_binary(input_s)
assert detect_private_key([path.strpath]) == expected_retval assert main([path.strpath]) == expected_retval

View file

@ -2,8 +2,8 @@ import io
import pytest import pytest
from pre_commit_hooks.end_of_file_fixer import end_of_file_fixer
from pre_commit_hooks.end_of_file_fixer import fix_file from pre_commit_hooks.end_of_file_fixer import fix_file
from pre_commit_hooks.end_of_file_fixer import main
# Input, expected return value, expected output # Input, expected return value, expected output
@ -35,7 +35,7 @@ def test_integration(input_s, expected_retval, output, tmpdir):
path = tmpdir.join('file.txt') path = tmpdir.join('file.txt')
path.write_binary(input_s) path.write_binary(input_s)
ret = end_of_file_fixer([path.strpath]) ret = main([path.strpath])
file_output = path.read_binary() file_output = path.read_binary()
assert file_output == output assert file_output == output

View file

@ -11,24 +11,24 @@ from pre_commit_hooks.util import cmd_output
def test_other_branch(temp_git_dir): def test_other_branch(temp_git_dir):
with temp_git_dir.as_cwd(): with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'anotherbranch') cmd_output('git', 'checkout', '-b', 'anotherbranch')
assert is_on_branch(('master',)) is False assert is_on_branch({'master'}) is False
def test_multi_branch(temp_git_dir): def test_multi_branch(temp_git_dir):
with temp_git_dir.as_cwd(): with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'another/branch') cmd_output('git', 'checkout', '-b', 'another/branch')
assert is_on_branch(('master',)) is False assert is_on_branch({'master'}) is False
def test_multi_branch_fail(temp_git_dir): def test_multi_branch_fail(temp_git_dir):
with temp_git_dir.as_cwd(): with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'another/branch') cmd_output('git', 'checkout', '-b', 'another/branch')
assert is_on_branch(('another/branch',)) is True assert is_on_branch({'another/branch'}) is True
def test_master_branch(temp_git_dir): def test_master_branch(temp_git_dir):
with temp_git_dir.as_cwd(): with temp_git_dir.as_cwd():
assert is_on_branch(('master',)) is True assert is_on_branch({'master'}) is True
def test_main_branch_call(temp_git_dir): def test_main_branch_call(temp_git_dir):

View file

@ -3,8 +3,8 @@ import shutil
import pytest import pytest
from six import PY2 from six import PY2
from pre_commit_hooks.pretty_format_json import main
from pre_commit_hooks.pretty_format_json import parse_num_to_int from pre_commit_hooks.pretty_format_json import parse_num_to_int
from pre_commit_hooks.pretty_format_json import pretty_format_json
from testing.util import get_resource_path from testing.util import get_resource_path
@ -23,8 +23,8 @@ def test_parse_num_to_int():
('pretty_formatted_json.json', 0), ('pretty_formatted_json.json', 0),
), ),
) )
def test_pretty_format_json(filename, expected_retval): def test_main(filename, expected_retval):
ret = pretty_format_json([get_resource_path(filename)]) ret = main([get_resource_path(filename)])
assert ret == expected_retval assert ret == expected_retval
@ -36,8 +36,8 @@ def test_pretty_format_json(filename, expected_retval):
('pretty_formatted_json.json', 0), ('pretty_formatted_json.json', 0),
), ),
) )
def test_unsorted_pretty_format_json(filename, expected_retval): def test_unsorted_main(filename, expected_retval):
ret = pretty_format_json(['--no-sort-keys', get_resource_path(filename)]) ret = main(['--no-sort-keys', get_resource_path(filename)])
assert ret == expected_retval assert ret == expected_retval
@ -51,17 +51,17 @@ def test_unsorted_pretty_format_json(filename, expected_retval):
('tab_pretty_formatted_json.json', 0), ('tab_pretty_formatted_json.json', 0),
), ),
) )
def test_tab_pretty_format_json(filename, expected_retval): # pragma: no cover def test_tab_main(filename, expected_retval): # pragma: no cover
ret = pretty_format_json(['--indent', '\t', get_resource_path(filename)]) ret = main(['--indent', '\t', get_resource_path(filename)])
assert ret == expected_retval assert ret == expected_retval
def test_non_ascii_pretty_format_json(): def test_non_ascii_main():
ret = pretty_format_json(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')]) ret = main(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')])
assert ret == 0 assert ret == 0
def test_autofix_pretty_format_json(tmpdir): def test_autofix_main(tmpdir):
srcfile = tmpdir.join('to_be_json_formatted.json') srcfile = tmpdir.join('to_be_json_formatted.json')
shutil.copyfile( shutil.copyfile(
get_resource_path('not_pretty_formatted_json.json'), get_resource_path('not_pretty_formatted_json.json'),
@ -69,30 +69,30 @@ def test_autofix_pretty_format_json(tmpdir):
) )
# now launch the autofix on that file # now launch the autofix on that file
ret = pretty_format_json(['--autofix', srcfile.strpath]) ret = main(['--autofix', srcfile.strpath])
# it should have formatted it # it should have formatted it
assert ret == 1 assert ret == 1
# file was formatted (shouldn't trigger linter again) # file was formatted (shouldn't trigger linter again)
ret = pretty_format_json([srcfile.strpath]) ret = main([srcfile.strpath])
assert ret == 0 assert ret == 0
def test_orderfile_get_pretty_format(): def test_orderfile_get_pretty_format():
ret = pretty_format_json(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')]) ret = main(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')])
assert ret == 0 assert ret == 0
def test_not_orderfile_get_pretty_format(): def test_not_orderfile_get_pretty_format():
ret = pretty_format_json(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')]) ret = main(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')])
assert ret == 1 assert ret == 1
def test_top_sorted_get_pretty_format(): def test_top_sorted_get_pretty_format():
ret = pretty_format_json(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')]) ret = main(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')])
assert ret == 0 assert ret == 0
def test_badfile_pretty_format_json(): def test_badfile_main():
ret = pretty_format_json([get_resource_path('ok_yaml.yaml')]) ret = main([get_resource_path('ok_yaml.yaml')])
assert ret == 1 assert ret == 1

View file

@ -1,7 +1,7 @@
import pytest import pytest
from pre_commit_hooks.requirements_txt_fixer import FAIL from pre_commit_hooks.requirements_txt_fixer import FAIL
from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt from pre_commit_hooks.requirements_txt_fixer import main
from pre_commit_hooks.requirements_txt_fixer import PASS from pre_commit_hooks.requirements_txt_fixer import PASS
from pre_commit_hooks.requirements_txt_fixer import Requirement from pre_commit_hooks.requirements_txt_fixer import Requirement
@ -36,7 +36,7 @@ def test_integration(input_s, expected_retval, output, tmpdir):
path = tmpdir.join('file.txt') path = tmpdir.join('file.txt')
path.write_binary(input_s) path.write_binary(input_s)
output_retval = fix_requirements_txt([path.strpath]) output_retval = main([path.strpath])
assert path.read_binary() == output assert path.read_binary() == output
assert output_retval == expected_retval assert output_retval == expected_retval
@ -44,7 +44,7 @@ def test_integration(input_s, expected_retval, output, tmpdir):
def test_requirement_object(): def test_requirement_object():
top_of_file = Requirement() top_of_file = Requirement()
top_of_file.comments.append('#foo') top_of_file.comments.append(b'#foo')
top_of_file.value = b'\n' top_of_file.value = b'\n'
requirement_foo = Requirement() requirement_foo = Requirement()

View file

@ -110,9 +110,9 @@ def test_first_key():
lines = ['# some comment', '"a": 42', 'b: 17', '', 'c: 19'] lines = ['# some comment', '"a": 42', 'b: 17', '', 'c: 19']
assert first_key(lines) == 'a": 42' assert first_key(lines) == 'a": 42'
# no lines # no lines (not a real situation)
lines = [] lines = []
assert first_key(lines) is None assert first_key(lines) == ''
@pytest.mark.parametrize('bad_lines,good_lines,_', TEST_SORTS) @pytest.mark.parametrize('bad_lines,good_lines,_', TEST_SORTS)

View file

@ -1,36 +1,36 @@
from pre_commit_hooks.tests_should_end_in_test import validate_files from pre_commit_hooks.tests_should_end_in_test import main
def test_validate_files_all_pass(): def test_main_all_pass():
ret = validate_files(['foo_test.py', 'bar_test.py']) ret = main(['foo_test.py', 'bar_test.py'])
assert ret == 0 assert ret == 0
def test_validate_files_one_fails(): def test_main_one_fails():
ret = validate_files(['not_test_ending.py', 'foo_test.py']) ret = main(['not_test_ending.py', 'foo_test.py'])
assert ret == 1 assert ret == 1
def test_validate_files_django_all_pass(): def test_main_django_all_pass():
ret = validate_files(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py']) ret = main(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py'])
assert ret == 0 assert ret == 0
def test_validate_files_django_one_fails(): def test_main_django_one_fails():
ret = validate_files(['--django', 'not_test_ending.py', 'test_foo.py']) ret = main(['--django', 'not_test_ending.py', 'test_foo.py'])
assert ret == 1 assert ret == 1
def test_validate_nested_files_django_one_fails(): def test_validate_nested_files_django_one_fails():
ret = validate_files(['--django', 'tests/not_test_ending.py', 'test_foo.py']) ret = main(['--django', 'tests/not_test_ending.py', 'test_foo.py'])
assert ret == 1 assert ret == 1
def test_validate_files_not_django_fails(): def test_main_not_django_fails():
ret = validate_files(['foo_test.py', 'bar_test.py', 'test_baz.py']) ret = main(['foo_test.py', 'bar_test.py', 'test_baz.py'])
assert ret == 1 assert ret == 1
def test_validate_files_django_fails(): def test_main_django_fails():
ret = validate_files(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py']) ret = main(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py'])
assert ret == 1 assert ret == 1

View file

@ -1,6 +1,6 @@
[tox] [tox]
# These should match the travis env list # These should match the travis env list
envlist = py27,py36,py37,pypy envlist = py27,py36,py37,pypy3
[testenv] [testenv]
deps = -rrequirements-dev.txt deps = -rrequirements-dev.txt