diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py index 7e6be95..7f64a23 100644 --- a/pre_commit_hooks/debug_statement_hook.py +++ b/pre_commit_hooks/debug_statement_hook.py @@ -20,6 +20,11 @@ DEBUG_STATEMENTS = { 'wdb', } +DEBUG_CALL_STATEMENTS = { + 'breakpoint', + 'print' +} + class Debug(NamedTuple): line: int @@ -45,7 +50,7 @@ class DebugStatementParser(ast.NodeVisitor): def visit_Call(self, node: ast.Call) -> None: """python3.7+ breakpoint()""" - if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': + if isinstance(node.func, ast.Name) and node.func.id in DEBUG_CALL_STATEMENTS: st = Debug(node.lineno, node.col_offset, node.func.id, 'called') self.breakpoints.append(st) self.generic_visit(node) diff --git a/tests/debug_statement_hook_test.py b/tests/debug_statement_hook_test.py index 5a8e0bb..a301fbf 100644 --- a/tests/debug_statement_hook_test.py +++ b/tests/debug_statement_hook_test.py @@ -30,6 +30,12 @@ def test_finds_breakpoint(): visitor = DebugStatementParser() visitor.visit(ast.parse('breakpoint()')) assert visitor.breakpoints == [Debug(1, 0, 'breakpoint', 'called')] + + +def test_finds_print(): + visitor = DebugStatementParser() + visitor.visit(ast.parse('print()')) + assert visitor.breakpoints == [Debug(1, 0, 'print', 'called')] def test_returns_one_for_failing_file(tmpdir):