check for print debug statement

This commit is contained in:
Aaditya Subedi 2025-11-19 12:57:00 +05:45
parent f4e025486b
commit 7c3f665ba8
2 changed files with 12 additions and 1 deletions

View file

@ -20,6 +20,11 @@ DEBUG_STATEMENTS = {
'wdb', 'wdb',
} }
DEBUG_CALL_STATEMENTS = {
'breakpoint',
'print'
}
class Debug(NamedTuple): class Debug(NamedTuple):
line: int line: int
@ -45,7 +50,7 @@ class DebugStatementParser(ast.NodeVisitor):
def visit_Call(self, node: ast.Call) -> None: def visit_Call(self, node: ast.Call) -> None:
"""python3.7+ breakpoint()""" """python3.7+ breakpoint()"""
if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': if isinstance(node.func, ast.Name) and node.func.id in DEBUG_CALL_STATEMENTS:
st = Debug(node.lineno, node.col_offset, node.func.id, 'called') st = Debug(node.lineno, node.col_offset, node.func.id, 'called')
self.breakpoints.append(st) self.breakpoints.append(st)
self.generic_visit(node) self.generic_visit(node)

View file

@ -30,6 +30,12 @@ def test_finds_breakpoint():
visitor = DebugStatementParser() visitor = DebugStatementParser()
visitor.visit(ast.parse('breakpoint()')) visitor.visit(ast.parse('breakpoint()'))
assert visitor.breakpoints == [Debug(1, 0, 'breakpoint', 'called')] assert visitor.breakpoints == [Debug(1, 0, 'breakpoint', 'called')]
def test_finds_print():
visitor = DebugStatementParser()
visitor.visit(ast.parse('print()'))
assert visitor.breakpoints == [Debug(1, 0, 'print', 'called')]
def test_returns_one_for_failing_file(tmpdir): def test_returns_one_for_failing_file(tmpdir):