diff --git a/.gitignore b/.gitignore index 4f6c5b7..c25bf32 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ .coverage .tox dist +.idea diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py index 0ef9bc7..a4980f7 100644 --- a/pre_commit_hooks/string_fixer.py +++ b/pre_commit_hooks/string_fixer.py @@ -1,57 +1,17 @@ from __future__ import annotations import argparse -import io -import re -import tokenize from typing import Sequence -START_QUOTE_RE = re.compile('^[a-zA-Z]*"') - - -def handle_match(token_text: str) -> str: - if '"""' in token_text or "'''" in token_text: - return token_text - - match = START_QUOTE_RE.match(token_text) - if match is not None: - meat = token_text[match.end():-1] - if '"' in meat or "'" in meat: - return token_text - else: - return match.group().replace('"', "'") + meat + "'" - else: - return token_text - - -def get_line_offsets_by_line_no(src: str) -> list[int]: - # Padded so we can index with line number - offsets = [-1, 0] - for line in src.splitlines(True): - offsets.append(offsets[-1] + len(line)) - return offsets +from pre_commit_hooks.util_string_fixer import fix_strings_in_file_contents def fix_strings(filename: str) -> int: with open(filename, encoding='UTF-8', newline='') as f: contents = f.read() - line_offsets = get_line_offsets_by_line_no(contents) - # Basically a mutable string - splitcontents = list(contents) + new_contents = fix_strings_in_file_contents(contents) - # Iterate in reverse so the offsets are always correct - tokens_l = list(tokenize.generate_tokens(io.StringIO(contents).readline)) - tokens = reversed(tokens_l) - for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens: - if token_type == tokenize.STRING: - new_text = handle_match(token_text) - splitcontents[ - line_offsets[srow] + scol: - line_offsets[erow] + ecol - ] = new_text - - new_contents = ''.join(splitcontents) if contents != new_contents: with open(filename, 'w', encoding='UTF-8', newline='') as f: 
f.write(new_contents) diff --git a/pre_commit_hooks/string_fixer_for_jupyter_notebooks.py b/pre_commit_hooks/string_fixer_for_jupyter_notebooks.py new file mode 100644 index 0000000..6f1de07 --- /dev/null +++ b/pre_commit_hooks/string_fixer_for_jupyter_notebooks.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import argparse +import json +from typing import Sequence + +from pre_commit_hooks.util_string_fixer import fix_strings_in_file_contents + + +def fix_strings(filename: str) -> int: + with open(filename) as f: + notebook_contents = json.load(f) + + cells = notebook_contents['cells'] + return_value = 0 + for cell in cells: + if cell['cell_type'] == 'code': + source_in_1_line = ''.join(cell['source']) + fixed = fix_strings_in_file_contents(source_in_1_line) + if fixed != source_in_1_line: + fixed_lines = fixed.split('\n') + cell['source'] = [_ + '\n' for _ in fixed_lines[:-1]] + [fixed_lines[-1]] + return_value = 1 + + if return_value == 1: + notebook_contents['cells'] = cells + with open(filename, 'w') as f: + json.dump(notebook_contents, f, indent=1) + f.write("\n") # because json.dump doesn't put \n at the end + + return return_value + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retv = 0 + + for filename in args.filenames: + return_value = fix_strings(filename) + if return_value != 0: + print(f'Fixing strings in {filename}') + retv |= return_value + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/util_string_fixer.py b/pre_commit_hooks/util_string_fixer.py new file mode 100644 index 0000000..0460127 --- /dev/null +++ b/pre_commit_hooks/util_string_fixer.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import io +import re +import tokenize + + +def handle_match(token_text: str) -> str: + if '"""' in token_text or "'''" in token_text: 
+ return token_text + + match = START_QUOTE_RE.match(token_text) + if match is not None: + meat = token_text[match.end():-1] + if '"' in meat or "'" in meat: + return token_text + else: + return match.group().replace('"', "'") + meat + "'" + else: + return token_text + + +def get_line_offsets_by_line_no(src: str) -> list[int]: + # Padded so we can index with line number + offsets = [-1, 0] + for line in src.splitlines(True): + offsets.append(offsets[-1] + len(line)) + return offsets + + +def fix_strings_in_file_contents(contents: str) -> str: + line_offsets = get_line_offsets_by_line_no(contents) + + # Basically a mutable string + splitcontents = list(contents) + + # Iterate in reverse so the offsets are always correct + tokens_l = list(tokenize.generate_tokens(io.StringIO(contents).readline)) + tokens = reversed(tokens_l) + for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens: + if token_type == tokenize.STRING: + new_text = handle_match(token_text) + splitcontents[ + line_offsets[srow] + scol: + line_offsets[erow] + ecol + ] = new_text + + new_contents = ''.join(splitcontents) + return new_contents + + +START_QUOTE_RE = re.compile('^[a-zA-Z]*"') diff --git a/testing/resources/after.ipynb b/testing/resources/after.ipynb new file mode 100644 index 0000000..f5914f1 --- /dev/null +++ b/testing/resources/after.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f3d342a9", + "metadata": {}, + "source": [ + "Multi-line Python cell without strings, just to make sure notebook cells are not accidentally changed:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ebb72f62", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1\n", + "b = 2\n", + "c = a + b" + ] + }, + { + "cell_type": "markdown", + "id": "9ccaec06", + "metadata": {}, + "source": [ + "Multi-line Python cell with a string to fix, just to make sure notebook cells can be correctly reconstructed:" + ] + }, + { + "cell_type": "code", + 
"execution_count": 2, + "id": "1fbe0c01", + "metadata": {}, + "outputs": [], + "source": [ + "d = 'hello'\n", + "e = 'world'\n", + "f = \"hello 'world'\"" + ] + }, + { + "cell_type": "markdown", + "id": "0ae6d462", + "metadata": {}, + "source": [ + "Base cases:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "80fcff95", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\'\"\"\"'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'' + '' + \"\\'\" + \"\\\"\" + '\\\"\\\"'" + ] + }, + { + "cell_type": "markdown", + "id": "6556ef23", + "metadata": {}, + "source": [ + "String somewhere in the line:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7f676729", + "metadata": {}, + "outputs": [], + "source": [ + "x = 'foo'" + ] + }, + { + "cell_type": "markdown", + "id": "f19cd463", + "metadata": {}, + "source": [ + "Test escaped characters:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "717cab5d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"'\"" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\\'\"" + ] + }, + { + "cell_type": "markdown", + "id": "d70877ff", + "metadata": {}, + "source": [ + "Docstring" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0bddf826", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "' Foo '" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\" Foo \"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3b42b721", + "metadata": {}, + "outputs": [], + "source": [ + "x = ' \\\n", + "foo \\\n", + "'" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "aa58b482", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'foobar'" + ] + }, + "execution_count": 8, 
+ "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'foo''bar'" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/testing/resources/before.ipynb b/testing/resources/before.ipynb new file mode 100644 index 0000000..32a617d --- /dev/null +++ b/testing/resources/before.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f3d342a9", + "metadata": {}, + "source": [ + "Multi-line Python cell without strings, just to make sure notebook cells are not accidentally changed:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ebb72f62", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1\n", + "b = 2\n", + "c = a + b" + ] + }, + { + "cell_type": "markdown", + "id": "9ccaec06", + "metadata": {}, + "source": [ + "Multi-line Python cell with a string to fix, just to make sure notebook cells can be correctly reconstructed:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1fbe0c01", + "metadata": {}, + "outputs": [], + "source": [ + "d = \"hello\"\n", + "e = \"world\"\n", + "f = \"hello 'world'\"" + ] + }, + { + "cell_type": "markdown", + "id": "0ae6d462", + "metadata": {}, + "source": [ + "Base cases:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "80fcff95", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\'\"\"\"'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'' + \"\" + \"\\'\" + \"\\\"\" + '\\\"\\\"'" + ] + }, + { + "cell_type": "markdown", + "id": "6556ef23", + "metadata": 
{}, + "source": [ + "String somewhere in the line:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7f676729", + "metadata": {}, + "outputs": [], + "source": [ + "x = \"foo\"" + ] + }, + { + "cell_type": "markdown", + "id": "f19cd463", + "metadata": {}, + "source": [ + "Test escaped characters:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "717cab5d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"'\"" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\\'\"" + ] + }, + { + "cell_type": "markdown", + "id": "d70877ff", + "metadata": {}, + "source": [ + "Docstring" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0bddf826", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "' Foo '" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\" Foo \"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3b42b721", + "metadata": {}, + "outputs": [], + "source": [ + "x = \" \\\n", + "foo \\\n", + "\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "aa58b482", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'foobar'" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"foo\"\"bar\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/string_fixer_for_jupyter_notebooks_test.py b/tests/string_fixer_for_jupyter_notebooks_test.py 
new file mode 100644 index 0000000..956b71b --- /dev/null +++ b/tests/string_fixer_for_jupyter_notebooks_test.py @@ -0,0 +1,16 @@ +from pre_commit_hooks.string_fixer_for_jupyter_notebooks import main +from testing.util import get_resource_path + + +def test_rewrite(tmpdir): + with open(get_resource_path('before.ipynb')) as f: + before_contents = f.read() + + with open(get_resource_path('after.ipynb')) as f: + after_contents = f.read() + + path = tmpdir.join('file.ipynb') + path.write(before_contents) + retval = main([str(path)]) + assert path.read() == after_contents + assert retval == 1