From 693709e7610eb967eb6a24aaf368473b07d1e956 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Thu, 11 Aug 2016 22:56:54 -0700 Subject: [PATCH] Allow encoding pragma to be customizable --- pre_commit_hooks/fix_encoding_pragma.py | 34 +++++++++++++++---- tests/fix_encoding_pragma_test.py | 44 +++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py index 8586937..5dcff93 100644 --- a/pre_commit_hooks/fix_encoding_pragma.py +++ b/pre_commit_hooks/fix_encoding_pragma.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals import argparse import collections -expected_pragma = b'# -*- coding: utf-8 -*-\n' +DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n' def has_coding(line): @@ -41,7 +41,7 @@ class ExpectedContents(collections.namedtuple( return self.pragma_status is expected_pragma_status -def _get_expected_contents(first_line, second_line, rest): +def _get_expected_contents(first_line, second_line, rest, expected_pragma): if first_line.startswith(b'#!'): shebang = first_line potential_coding = second_line @@ -63,8 +63,10 @@ def _get_expected_contents(first_line, second_line, rest): ) -def fix_encoding_pragma(f, remove=False): - expected = _get_expected_contents(f.readline(), f.readline(), f.read()) +def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): + expected = _get_expected_contents( + f.readline(), f.readline(), f.read(), expected_pragma, + ) # Special cases for empty files if not expected.rest.strip(): @@ -91,9 +93,25 @@ def fix_encoding_pragma(f, remove=False): return 1 +def _normalize_pragma(pragma): + if not isinstance(pragma, bytes): + pragma = pragma.encode('UTF-8') + return pragma.rstrip() + b'\n' + + +def _to_disp(pragma): + return pragma.decode().rstrip() + + def main(argv=None): parser = argparse.ArgumentParser('Fixes the encoding pragma of python files') parser.add_argument('filenames', nargs='*', help='Filenames to fix') + parser.add_argument( + '--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma, + help='The encoding pragma to use. Default: {}'.format( + _to_disp(DEFAULT_PRAGMA), + ), + ) parser.add_argument( '--remove', action='store_true', help='Remove the encoding pragma (Useful in a python3-only codebase)', @@ -109,10 +127,14 @@ def main(argv=None): for filename in args.filenames: with open(filename, 'r+b') as f: - file_ret = fix_encoding_pragma(f, remove=args.remove) + file_ret = fix_encoding_pragma( + f, remove=args.remove, expected_pragma=args.pragma, + ) retv |= file_ret if file_ret: - print(fmt.format(pragma=expected_pragma, filename=filename)) + print(fmt.format( + pragma=_to_disp(args.pragma), filename=filename, + )) return retv diff --git a/tests/fix_encoding_pragma_test.py b/tests/fix_encoding_pragma_test.py index a9502a2..d49f1ba 100644 --- a/tests/fix_encoding_pragma_test.py +++ b/tests/fix_encoding_pragma_test.py @@ -5,6 +5,7 @@ import io import pytest +from pre_commit_hooks.fix_encoding_pragma import _normalize_pragma from pre_commit_hooks.fix_encoding_pragma import fix_encoding_pragma from pre_commit_hooks.fix_encoding_pragma import main @@ -106,3 +107,46 @@ def test_not_ok_inputs(input_str, output): assert fix_encoding_pragma(bytesio) == 1 bytesio.seek(0) assert bytesio.read() == output + + +def test_ok_input_alternate_pragma(): + input_s = b'# coding: utf-8\nx = 1\n' + bytesio = io.BytesIO(input_s) + ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n') + assert ret == 0 + bytesio.seek(0) + assert bytesio.read() == input_s + + +def test_not_ok_input_alternate_pragma(): + bytesio = io.BytesIO(b'x = 1\n') + ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n') + assert ret == 1 + bytesio.seek(0) + assert bytesio.read() == b'# coding: utf-8\nx = 1\n' + + +@pytest.mark.parametrize( + ('input_s', 'expected'), + ( + # Python 2 cli parameters are bytes + (b'# coding: utf-8', b'# coding: utf-8\n'), + # Python 3 cli parameters are text + ('# coding: utf-8', b'# coding: utf-8\n'), + # trailing whitespace + ('# coding: utf-8\n', b'# coding: utf-8\n'), + ), +) +def test_normalize_pragma(input_s, expected): + assert _normalize_pragma(input_s) == expected + + +def test_integration_alternate_pragma(tmpdir, capsys): + f = tmpdir.join('f.py') + f.write('x = 1\n') + + pragma = '# coding: utf-8' + assert main((f.strpath, '--pragma', pragma)) == 1 + assert f.read() == '# coding: utf-8\nx = 1\n' + out, _ = capsys.readouterr() + assert out == 'Added `# coding: utf-8` to {}\n'.format(f.strpath)