fix nested calls for check-builtin-literals

This commit is contained in:
Anthony Sottile 2025-10-16 09:30:42 -04:00
parent 9ba250d7b3
commit a804ba5239
2 changed files with 22 additions and 23 deletions

View file

@ -26,36 +26,37 @@ class Call(NamedTuple):
class Visitor(ast.NodeVisitor): class Visitor(ast.NodeVisitor):
def __init__( def __init__(
self, self,
ignore: Sequence[str] | None = None, ignore: set[str],
allow_dict_kwargs: bool = True, allow_dict_kwargs: bool = True,
) -> None: ) -> None:
self.builtin_type_calls: list[Call] = [] self.builtin_type_calls: list[Call] = []
self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs self.allow_dict_kwargs = allow_dict_kwargs
self._disallowed = BUILTIN_TYPES.keys() - ignore
def _check_dict_call(self, node: ast.Call) -> bool: def _check_dict_call(self, node: ast.Call) -> bool:
return self.allow_dict_kwargs and bool(node.keywords) return self.allow_dict_kwargs and bool(node.keywords)
def visit_Call(self, node: ast.Call) -> None: def visit_Call(self, node: ast.Call) -> None:
if not isinstance(node.func, ast.Name): if (
# Ignore functions that are object attributes (`foo.bar()`). # Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what # Assume that if the user calls `builtins.list()`, they know what
# they're doing. # they're doing.
return isinstance(node.func, ast.Name) and
if node.func.id not in set(BUILTIN_TYPES).difference(self.ignore): node.func.id in self._disallowed and
return (node.func.id != 'dict' or not self._check_dict_call(node)) and
if node.func.id == 'dict' and self._check_dict_call(node): not node.args
return ):
elif node.args: self.builtin_type_calls.append(
return Call(node.func.id, node.lineno, node.col_offset),
self.builtin_type_calls.append( )
Call(node.func.id, node.lineno, node.col_offset),
) self.generic_visit(node)
def check_file( def check_file(
filename: str, filename: str,
ignore: Sequence[str] | None = None, *,
ignore: set[str],
allow_dict_kwargs: bool = True, allow_dict_kwargs: bool = True,
) -> list[Call]: ) -> list[Call]:
with open(filename, 'rb') as f: with open(filename, 'rb') as f:

View file

@ -38,11 +38,6 @@ t1 = ()
''' '''
@pytest.fixture
def visitor():
return Visitor()
@pytest.mark.parametrize( @pytest.mark.parametrize(
('expression', 'calls'), ('expression', 'calls'),
[ [
@ -85,7 +80,8 @@ def visitor():
('builtins.tuple()', []), ('builtins.tuple()', []),
], ],
) )
def test_non_dict_exprs(visitor, expression, calls): def test_non_dict_exprs(expression, calls):
visitor = Visitor(ignore=set())
visitor.visit(ast.parse(expression)) visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls assert visitor.builtin_type_calls == calls
@ -102,7 +98,8 @@ def test_non_dict_exprs(visitor, expression, calls):
('builtins.dict()', []), ('builtins.dict()', []),
], ],
) )
def test_dict_allow_kwargs_exprs(visitor, expression, calls): def test_dict_allow_kwargs_exprs(expression, calls):
visitor = Visitor(ignore=set())
visitor.visit(ast.parse(expression)) visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls assert visitor.builtin_type_calls == calls
@ -114,17 +111,18 @@ def test_dict_allow_kwargs_exprs(visitor, expression, calls):
('dict(a=1, b=2, c=3)', [Call('dict', 1, 0)]), ('dict(a=1, b=2, c=3)', [Call('dict', 1, 0)]),
("dict(**{'a': 1, 'b': 2, 'c': 3})", [Call('dict', 1, 0)]), ("dict(**{'a': 1, 'b': 2, 'c': 3})", [Call('dict', 1, 0)]),
('builtins.dict()', []), ('builtins.dict()', []),
pytest.param('f(dict())', [Call('dict', 1, 2)], id='nested'),
], ],
) )
def test_dict_no_allow_kwargs_exprs(expression, calls): def test_dict_no_allow_kwargs_exprs(expression, calls):
visitor = Visitor(allow_dict_kwargs=False) visitor = Visitor(ignore=set(), allow_dict_kwargs=False)
visitor.visit(ast.parse(expression)) visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls assert visitor.builtin_type_calls == calls
def test_ignore_constructors(): def test_ignore_constructors():
visitor = Visitor( visitor = Visitor(
ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'), ignore={'complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'},
) )
visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS)) visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS))
assert visitor.builtin_type_calls == [] assert visitor.builtin_type_calls == []