[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2024-04-13 00:00:18 +00:00
parent 72ad6dc953
commit f4cd1ba0d6
813 changed files with 66015 additions and 58839 deletions

View file

@ -39,4 +39,4 @@ repos:
rev: v1.9.0
hooks:
- id: mypy
additional_dependencies: [types-all]
additional_dependencies: [types-all]

View file

@ -44,7 +44,7 @@ command:
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
For more information on Execution Policies:
For more information on Execution Policies:
https://go.microsoft.com/fwlink/?LinkID=135170
#>

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from coverage.cmdline import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from coverage.cmdline import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from coverage.cmdline import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,16 +1,20 @@
import sys
from __future__ import annotations
import importlib
import os
import re
import importlib
import sys
import warnings
is_pypy = '__pypy__' in sys.builtin_module_names
warnings.filterwarnings('ignore',
r'.+ distutils\b.+ deprecated',
DeprecationWarning)
warnings.filterwarnings(
'ignore',
r'.+ distutils\b.+ deprecated',
DeprecationWarning,
)
def warn_distutils_present():
@ -21,18 +25,19 @@ def warn_distutils_present():
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
return
warnings.warn(
"Distutils was imported before Setuptools, but importing Setuptools "
"also replaces the `distutils` module in `sys.modules`. This may lead "
"to undesirable behaviors or errors. To avoid these issues, avoid "
"using distutils directly, ensure that setuptools is installed in the "
"traditional way (e.g. not an editable install), and/or make sure "
"that setuptools is always imported before distutils.")
'Distutils was imported before Setuptools, but importing Setuptools '
'also replaces the `distutils` module in `sys.modules`. This may lead '
'to undesirable behaviors or errors. To avoid these issues, avoid '
'using distutils directly, ensure that setuptools is installed in the '
'traditional way (e.g. not an editable install), and/or make sure '
'that setuptools is always imported before distutils.',
)
def clear_distutils():
if 'distutils' not in sys.modules:
return
warnings.warn("Setuptools is replacing distutils.")
warnings.warn('Setuptools is replacing distutils.')
mods = [name for name in sys.modules if re.match(r'distutils\b', name)]
for name in mods:
del sys.modules[name]
@ -74,7 +79,7 @@ class DistutilsMetaFinder:
if path is not None:
return
method_name = 'spec_for_{fullname}'.format(**locals())
method_name = f'spec_for_{fullname}'
method = getattr(self, method_name, lambda: None)
return method()

View file

@ -1 +1,2 @@
from __future__ import annotations
__import__('_distutils_hack').do_override()

View file

@ -1,4 +1,5 @@
__all__ = ["__version__", "version_tuple"]
from __future__ import annotations
__all__ = ['__version__', 'version_tuple']
try:
from ._version import version as __version__
@ -6,5 +7,5 @@ try:
except ImportError: # pragma: no cover
# broken installation, we don't even try
# unknown only works because we do poor mans version compare
__version__ = "unknown"
version_tuple = (0, 0, "unknown") # type:ignore[assignment]
__version__ = 'unknown'
version_tuple = (0, 0, 'unknown') # type:ignore[assignment]

View file

@ -61,11 +61,12 @@ If things do not work right away:
which should throw a KeyError: 'COMPLINE' (which is properly set by the
global argcomplete script).
"""
from __future__ import annotations
import argparse
from glob import glob
import os
import sys
from glob import glob
from typing import Any
from typing import List
from typing import Optional
@ -77,7 +78,7 @@ class FastFilesCompleter:
def __init__(self, directories: bool = True) -> None:
self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> List[str]:
def __call__(self, prefix: str, **kwargs: Any) -> list[str]:
# Only called on non option completions.
if os.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.sep)
@ -85,26 +86,26 @@ class FastFilesCompleter:
prefix_dir = 0
completion = []
globbed = []
if "*" not in prefix and "?" not in prefix:
if '*' not in prefix and '?' not in prefix:
# We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.sep:
globbed.extend(glob(prefix + ".*"))
prefix += "*"
globbed.extend(glob(prefix + '.*'))
prefix += '*'
globbed.extend(glob(prefix))
for x in sorted(globbed):
if os.path.isdir(x):
x += "/"
x += '/'
# Append stripping the prefix (like bash, not like compgen).
completion.append(x[prefix_dir:])
return completion
if os.environ.get("_ARGCOMPLETE"):
if os.environ.get('_ARGCOMPLETE'):
try:
import argcomplete.completers
except ImportError:
sys.exit(-1)
filescompleter: Optional[FastFilesCompleter] = FastFilesCompleter()
filescompleter: FastFilesCompleter | None = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False)

View file

@ -1,4 +1,5 @@
"""Python inspection/code generation API."""
from __future__ import annotations
from .code import Code
from .code import ExceptionInfo
@ -12,13 +13,13 @@ from .source import Source
__all__ = [
"Code",
"ExceptionInfo",
"filter_traceback",
"Frame",
"getfslineno",
"getrawcode",
"Traceback",
"TracebackEntry",
"Source",
'Code',
'ExceptionInfo',
'filter_traceback',
'Frame',
'getfslineno',
'getrawcode',
'Traceback',
'TracebackEntry',
'Source',
]

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,13 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import ast
from bisect import bisect_right
import inspect
import textwrap
import tokenize
import types
import warnings
from bisect import bisect_right
from typing import Iterable
from typing import Iterator
from typing import List
@ -12,7 +15,6 @@ from typing import Optional
from typing import overload
from typing import Tuple
from typing import Union
import warnings
class Source:
@ -23,20 +25,20 @@ class Source:
def __init__(self, obj: object = None) -> None:
if not obj:
self.lines: List[str] = []
self.lines: list[str] = []
elif isinstance(obj, Source):
self.lines = obj.lines
elif isinstance(obj, (tuple, list)):
self.lines = deindent(x.rstrip("\n") for x in obj)
self.lines = deindent(x.rstrip('\n') for x in obj)
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
self.lines = deindent(obj.split('\n'))
else:
try:
rawcode = getrawcode(obj)
src = inspect.getsource(rawcode)
except TypeError:
src = inspect.getsource(obj) # type: ignore[arg-type]
self.lines = deindent(src.split("\n"))
self.lines = deindent(src.split('\n'))
def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
@ -51,17 +53,17 @@ class Source:
...
@overload
def __getitem__(self, key: slice) -> "Source":
def __getitem__(self, key: slice) -> Source:
...
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:
def __getitem__(self, key: int | slice) -> str | Source:
if isinstance(key, int):
return self.lines[key]
else:
if key.step not in (None, 1):
raise IndexError("cannot slice a Source with a step")
raise IndexError('cannot slice a Source with a step')
newsource = Source()
newsource.lines = self.lines[key.start : key.stop]
newsource.lines = self.lines[key.start: key.stop]
return newsource
def __iter__(self) -> Iterator[str]:
@ -70,7 +72,7 @@ class Source:
def __len__(self) -> int:
return len(self.lines)
def strip(self) -> "Source":
def strip(self) -> Source:
"""Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self)
while start < end and not self.lines[start].strip():
@ -81,35 +83,35 @@ class Source:
source.lines[:] = self.lines[start:end]
return source
def indent(self, indent: str = " " * 4) -> "Source":
def indent(self, indent: str = ' ' * 4) -> Source:
"""Return a copy of the source object with all lines indented by the
given indent-string."""
newsource = Source()
newsource.lines = [(indent + line) for line in self.lines]
return newsource
def getstatement(self, lineno: int) -> "Source":
def getstatement(self, lineno: int) -> Source:
"""Return Source statement which contains the given linenumber
(counted from 0)."""
start, end = self.getstatementrange(lineno)
return self[start:end]
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
def getstatementrange(self, lineno: int) -> tuple[int, int]:
"""Return (start, end) tuple which spans the minimal statement region
which containing the given lineno."""
if not (0 <= lineno < len(self)):
raise IndexError("lineno out of range")
raise IndexError('lineno out of range')
ast, start, end = getstatementrange_ast(lineno, self)
return start, end
def deindent(self) -> "Source":
def deindent(self) -> Source:
"""Return a new Source object deindented."""
newsource = Source()
newsource.lines[:] = deindent(self.lines)
return newsource
def __str__(self) -> str:
return "\n".join(self.lines)
return '\n'.join(self.lines)
#
@ -117,7 +119,7 @@ class Source:
#
def findsource(obj) -> Tuple[Optional[Source], int]:
def findsource(obj) -> tuple[Source | None, int]:
try:
sourcelines, lineno = inspect.findsource(obj)
except Exception:
@ -134,20 +136,20 @@ def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
except AttributeError:
pass
if trycall:
call = getattr(obj, "__call__", None)
call = getattr(obj, '__call__', None)
if call and not isinstance(obj, type):
return getrawcode(call, trycall=False)
raise TypeError(f"could not get code object for {obj!r}")
raise TypeError(f'could not get code object for {obj!r}')
def deindent(lines: Iterable[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines()
def deindent(lines: Iterable[str]) -> list[str]:
return textwrap.dedent('\n'.join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
# Flatten all statements and except handlers into one lineno-list.
# AST's line numbers start indexing at 1.
values: List[int] = []
values: list[int] = []
for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
# The lineno points to the class/def, so need to include the decorators.
@ -155,8 +157,8 @@ def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[i
for d in x.decorator_list:
values.append(d.lineno - 1)
values.append(x.lineno - 1)
for name in ("finalbody", "orelse"):
val: Optional[List[ast.stmt]] = getattr(x, name, None)
for name in ('finalbody', 'orelse'):
val: list[ast.stmt] | None = getattr(x, name, None)
if val:
# Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1)
@ -174,15 +176,15 @@ def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
astnode: ast.AST | None = None,
) -> tuple[ast.AST, int, int]:
if astnode is None:
content = str(source)
# See #4260:
# Don't produce duplicate warnings when compiling source to find AST.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = ast.parse(content, "source", "exec")
warnings.simplefilter('ignore')
astnode = ast.parse(content, 'source', 'exec')
start, end = get_statement_startend2(lineno, astnode)
# We need to correct the end:
@ -200,7 +202,7 @@ def getstatementrange_ast(
block_finder.started = (
bool(source.lines[start]) and source.lines[start][0].isspace()
)
it = ((x + "\n") for x in source.lines[start:end])
it = ((x + '\n') for x in source.lines[start:end])
try:
for tok in tokenize.generate_tokens(lambda: next(it)):
block_finder.tokeneater(*tok)
@ -212,7 +214,7 @@ def getstatementrange_ast(
# The end might still point to a comment or empty line, correct it.
while end:
line = source.lines[end - 1].lstrip()
if line.startswith("#") or not line:
if line.startswith('#') or not line:
end -= 1
else:
break

View file

@ -1,8 +1,10 @@
from __future__ import annotations
from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter
__all__ = [
"TerminalWriter",
"get_terminal_width",
'TerminalWriter',
'get_terminal_width',
]

View file

@ -13,11 +13,13 @@
# tuples with fairly non-descriptive content. This is modeled very much
# after Lisp/Scheme - style pretty-printing of lists. If you find it
# useful, thank small children who sleep at night.
from __future__ import annotations
import collections as _collections
import dataclasses as _dataclasses
from io import StringIO as _StringIO
import re
import types as _types
from io import StringIO as _StringIO
from typing import Any
from typing import Callable
from typing import Dict
@ -39,7 +41,7 @@ class _safe_key:
"""
__slots__ = ["obj"]
__slots__ = ['obj']
def __init__(self, obj):
self.obj = obj
@ -64,7 +66,7 @@ class PrettyPrinter:
self,
indent: int = 4,
width: int = 80,
depth: Optional[int] = None,
depth: int | None = None,
) -> None:
"""Handle pretty printing operations onto a stream using a set of
configured parameters.
@ -80,11 +82,11 @@ class PrettyPrinter:
"""
if indent < 0:
raise ValueError("indent must be >= 0")
raise ValueError('indent must be >= 0')
if depth is not None and depth <= 0:
raise ValueError("depth must be > 0")
raise ValueError('depth must be > 0')
if not width:
raise ValueError("width must be != 0")
raise ValueError('width must be != 0')
self._depth = depth
self._indent_per_level = indent
self._width = width
@ -100,7 +102,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
objid = id(object)
@ -114,17 +116,16 @@ class PrettyPrinter:
p(self, object, stream, indent, allowance, context, level + 1)
context.remove(objid)
elif (
_dataclasses.is_dataclass(object)
and not isinstance(object, type)
and object.__dataclass_params__.repr
and
_dataclasses.is_dataclass(object) and
not isinstance(object, type) and
object.__dataclass_params__.repr and
# Check dataclass has generated repr method.
hasattr(object.__repr__, "__wrapped__")
and "__create_fn__" in object.__repr__.__wrapped__.__qualname__
hasattr(object.__repr__, '__wrapped__') and
'__create_fn__' in object.__repr__.__wrapped__.__qualname__
):
context.add(objid)
self._pprint_dataclass(
object, stream, indent, allowance, context, level + 1
object, stream, indent, allowance, context, level + 1,
)
context.remove(objid)
else:
@ -136,7 +137,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
cls_name = object.__class__.__name__
@ -145,13 +146,13 @@ class PrettyPrinter:
for f in _dataclasses.fields(object)
if f.repr
]
stream.write(cls_name + "(")
stream.write(cls_name + '(')
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
stream.write(')')
_dispatch: Dict[
_dispatch: dict[
Callable[..., str],
Callable[["PrettyPrinter", Any, IO[str], int, int, Set[int], int], None],
Callable[[PrettyPrinter, Any, IO[str], int, int, set[int], int], None],
] = {}
def _pprint_dict(
@ -160,14 +161,14 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
write = stream.write
write("{")
write('{')
items = sorted(object.items(), key=_safe_tuple)
self._format_dict_items(items, stream, indent, allowance, context, level)
write("}")
write('}')
_dispatch[dict.__repr__] = _pprint_dict
@ -177,16 +178,16 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
cls = object.__class__
stream.write(cls.__name__ + "(")
stream.write(cls.__name__ + '(')
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
stream.write(')')
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
@ -196,12 +197,12 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
stream.write("[")
stream.write('[')
self._format_items(object, stream, indent, allowance, context, level)
stream.write("]")
stream.write(']')
_dispatch[list.__repr__] = _pprint_list
@ -211,12 +212,12 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
stream.write("(")
stream.write('(')
self._format_items(object, stream, indent, allowance, context, level)
stream.write(")")
stream.write(')')
_dispatch[tuple.__repr__] = _pprint_tuple
@ -226,7 +227,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
if not len(object):
@ -234,11 +235,11 @@ class PrettyPrinter:
return
typ = object.__class__
if typ is set:
stream.write("{")
endchar = "}"
stream.write('{')
endchar = '}'
else:
stream.write(typ.__name__ + "({")
endchar = "})"
stream.write(typ.__name__ + '({')
endchar = '})'
object = sorted(object, key=_safe_key)
self._format_items(object, stream, indent, allowance, context, level)
stream.write(endchar)
@ -252,7 +253,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
write = stream.write
@ -273,12 +274,12 @@ class PrettyPrinter:
chunks.append(rep)
else:
# A list of alternating (non-space, space) strings
parts = re.findall(r"\S*\s*", line)
parts = re.findall(r'\S*\s*', line)
assert parts
assert not parts[-1]
parts.pop() # drop empty last part
max_width2 = max_width
current = ""
current = ''
for j, part in enumerate(parts):
candidate = current + part
if j == len(parts) - 1 and i == len(lines) - 1:
@ -295,13 +296,13 @@ class PrettyPrinter:
write(rep)
return
if level == 1:
write("(")
write('(')
for i, rep in enumerate(chunks):
if i > 0:
write("\n" + " " * indent)
write('\n' + ' ' * indent)
write(rep)
if level == 1:
write(")")
write(')')
_dispatch[str.__repr__] = _pprint_str
@ -311,7 +312,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
write = stream.write
@ -322,15 +323,15 @@ class PrettyPrinter:
if parens:
indent += 1
allowance += 1
write("(")
delim = ""
write('(')
delim = ''
for rep in _wrap_bytes_repr(object, self._width - indent, allowance):
write(delim)
write(rep)
if not delim:
delim = "\n" + " " * indent
delim = '\n' + ' ' * indent
if parens:
write(")")
write(')')
_dispatch[bytes.__repr__] = _pprint_bytes
@ -340,15 +341,15 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
write = stream.write
write("bytearray(")
write('bytearray(')
self._pprint_bytes(
bytes(object), stream, indent + 10, allowance + 1, context, level + 1
bytes(object), stream, indent + 10, allowance + 1, context, level + 1,
)
write(")")
write(')')
_dispatch[bytearray.__repr__] = _pprint_bytearray
@ -358,12 +359,12 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
stream.write("mappingproxy(")
stream.write('mappingproxy(')
self._format(object.copy(), stream, indent, allowance, context, level)
stream.write(")")
stream.write(')')
_dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy
@ -373,29 +374,29 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
if type(object) is _types.SimpleNamespace:
# The SimpleNamespace repr is "namespace" instead of the class
# name, so we do the same here. For subclasses; use the class name.
cls_name = "namespace"
cls_name = 'namespace'
else:
cls_name = object.__class__.__name__
items = object.__dict__.items()
stream.write(cls_name + "(")
stream.write(cls_name + '(')
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
stream.write(')')
_dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace
def _format_dict_items(
self,
items: List[Tuple[Any, Any]],
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
if not items:
@ -403,23 +404,23 @@ class PrettyPrinter:
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
delimnl = '\n' + ' ' * item_indent
for key, ent in items:
write(delimnl)
write(self._repr(key, context, level))
write(": ")
write(': ')
self._format(ent, stream, item_indent, 1, context, level)
write(",")
write(',')
write("\n" + " " * indent)
write('\n' + ' ' * indent)
def _format_namespace_items(
self,
items: List[Tuple[Any, Any]],
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
if not items:
@ -427,15 +428,15 @@ class PrettyPrinter:
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
delimnl = '\n' + ' ' * item_indent
for key, ent in items:
write(delimnl)
write(key)
write("=")
write('=')
if id(ent) in context:
# Special-case representation of recursion to match standard
# recursive dataclass repr.
write("...")
write('...')
else:
self._format(
ent,
@ -446,17 +447,17 @@ class PrettyPrinter:
level,
)
write(",")
write(',')
write("\n" + " " * indent)
write('\n' + ' ' * indent)
def _format_items(
self,
items: List[Any],
items: list[Any],
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
if not items:
@ -464,16 +465,16 @@ class PrettyPrinter:
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
delimnl = '\n' + ' ' * item_indent
for item in items:
write(delimnl)
self._format(item, stream, item_indent, 1, context, level)
write(",")
write(',')
write("\n" + " " * indent)
write('\n' + ' ' * indent)
def _repr(self, object: Any, context: Set[int], level: int) -> str:
def _repr(self, object: Any, context: set[int], level: int) -> str:
return self._safe_repr(object, context.copy(), self._depth, level)
def _pprint_default_dict(
@ -482,13 +483,13 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
rdf = self._repr(object.default_factory, context, level)
stream.write(f"{object.__class__.__name__}({rdf}, ")
stream.write(f'{object.__class__.__name__}({rdf}, ')
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
stream.write(')')
_dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict
@ -498,18 +499,18 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
stream.write(object.__class__.__name__ + '(')
if object:
stream.write("{")
stream.write('{')
items = object.most_common()
self._format_dict_items(items, stream, indent, allowance, context, level)
stream.write("}")
stream.write('}')
stream.write(")")
stream.write(')')
_dispatch[_collections.Counter.__repr__] = _pprint_counter
@ -519,16 +520,16 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
stream.write(repr(object))
return
stream.write(object.__class__.__name__ + "(")
stream.write(object.__class__.__name__ + '(')
self._format_items(object.maps, stream, indent, allowance, context, level)
stream.write(")")
stream.write(')')
_dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map
@ -538,16 +539,16 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
stream.write(object.__class__.__name__ + '(')
if object.maxlen is not None:
stream.write("maxlen=%d, " % object.maxlen)
stream.write("[")
stream.write('maxlen=%d, ' % object.maxlen)
stream.write('[')
self._format_items(object, stream, indent, allowance + 1, context, level)
stream.write("])")
stream.write('])')
_dispatch[_collections.deque.__repr__] = _pprint_deque
@ -557,7 +558,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
@ -570,7 +571,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
@ -583,7 +584,7 @@ class PrettyPrinter:
stream: IO[str],
indent: int,
allowance: int,
context: Set[int],
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
@ -591,49 +592,49 @@ class PrettyPrinter:
_dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr(
self, object: Any, context: Set[int], maxlevels: Optional[int], level: int
self, object: Any, context: set[int], maxlevels: int | None, level: int,
) -> str:
typ = type(object)
if typ in _builtin_scalars:
return repr(object)
r = getattr(typ, "__repr__", None)
r = getattr(typ, '__repr__', None)
if issubclass(typ, dict) and r is dict.__repr__:
if not object:
return "{}"
return '{}'
objid = id(object)
if maxlevels and level >= maxlevels:
return "{...}"
return '{...}'
if objid in context:
return _recursion(object)
context.add(objid)
components: List[str] = []
components: list[str] = []
append = components.append
level += 1
for k, v in sorted(object.items(), key=_safe_tuple):
krepr = self._safe_repr(k, context, maxlevels, level)
vrepr = self._safe_repr(v, context, maxlevels, level)
append(f"{krepr}: {vrepr}")
append(f'{krepr}: {vrepr}')
context.remove(objid)
return "{%s}" % ", ".join(components)
return '{%s}' % ', '.join(components)
if (issubclass(typ, list) and r is list.__repr__) or (
issubclass(typ, tuple) and r is tuple.__repr__
):
if issubclass(typ, list):
if not object:
return "[]"
format = "[%s]"
return '[]'
format = '[%s]'
elif len(object) == 1:
format = "(%s,)"
format = '(%s,)'
else:
if not object:
return "()"
format = "(%s)"
return '()'
format = '(%s)'
objid = id(object)
if maxlevels and level >= maxlevels:
return format % "..."
return format % '...'
if objid in context:
return _recursion(object)
context.add(objid)
@ -644,25 +645,25 @@ class PrettyPrinter:
orepr = self._safe_repr(o, context, maxlevels, level)
append(orepr)
context.remove(objid)
return format % ", ".join(components)
return format % ', '.join(components)
return repr(object)
_builtin_scalars = frozenset(
{str, bytes, bytearray, float, complex, bool, type(None), int}
{str, bytes, bytearray, float, complex, bool, type(None), int},
)
def _recursion(object: Any) -> str:
return f"<Recursion on {type(object).__name__} with id={id(object)}>"
return f'<Recursion on {type(object).__name__} with id={id(object)}>'
def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
current = b""
current = b''
last = len(object) // 4 * 4
for i in range(0, len(object), 4):
part = object[i : i + 4]
part = object[i: i + 4]
candidate = current + part
if i == last:
width -= allowance

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import pprint
import reprlib
from typing import Optional
@ -18,9 +20,9 @@ def _format_repr_exception(exc: BaseException, obj: object) -> str:
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
exc_info = f"unpresentable exception ({_try_repr_or_str(exc)})"
exc_info = f'unpresentable exception ({_try_repr_or_str(exc)})'
return (
f"<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>"
f'<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>'
)
@ -28,7 +30,7 @@ def _ellipsize(s: str, maxsize: int) -> str:
if len(s) > maxsize:
i = max(0, (maxsize - 3) // 2)
j = max(0, maxsize - 3 - i)
return s[:i] + "..." + s[len(s) - j :]
return s[:i] + '...' + s[len(s) - j:]
return s
@ -38,7 +40,7 @@ class SafeRepr(reprlib.Repr):
information on exceptions raised during the call.
"""
def __init__(self, maxsize: Optional[int], use_ascii: bool = False) -> None:
def __init__(self, maxsize: int | None, use_ascii: bool = False) -> None:
"""
:param maxsize:
If not None, will truncate the resulting repr to that specific size, using ellipsis
@ -97,7 +99,7 @@ DEFAULT_REPR_MAX_SIZE = 240
def saferepr(
obj: object, maxsize: Optional[int] = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False
obj: object, maxsize: int | None = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False,
) -> str:
"""Return a size-limited safe repr-string for the given object.

View file

@ -1,4 +1,5 @@
"""Helper functions for writing to terminals and files."""
from __future__ import annotations
import os
import shutil
@ -26,16 +27,16 @@ def get_terminal_width() -> int:
def should_do_markup(file: TextIO) -> bool:
if os.environ.get("PY_COLORS") == "1":
if os.environ.get('PY_COLORS') == '1':
return True
if os.environ.get("PY_COLORS") == "0":
if os.environ.get('PY_COLORS') == '0':
return False
if os.environ.get("NO_COLOR"):
if os.environ.get('NO_COLOR'):
return False
if os.environ.get("FORCE_COLOR"):
if os.environ.get('FORCE_COLOR'):
return True
return (
hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb"
hasattr(file, 'isatty') and file.isatty() and os.environ.get('TERM') != 'dumb'
)
@ -64,10 +65,10 @@ class TerminalWriter:
invert=7,
)
def __init__(self, file: Optional[TextIO] = None) -> None:
def __init__(self, file: TextIO | None = None) -> None:
if file is None:
file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
if hasattr(file, 'isatty') and file.isatty() and sys.platform == 'win32':
try:
import colorama
except ImportError:
@ -77,8 +78,8 @@ class TerminalWriter:
assert file is not None
self._file = file
self.hasmarkup = should_do_markup(file)
self._current_line = ""
self._terminal_width: Optional[int] = None
self._current_line = ''
self._terminal_width: int | None = None
self.code_highlight = True
@property
@ -99,25 +100,25 @@ class TerminalWriter:
def markup(self, text: str, **markup: bool) -> str:
for name in markup:
if name not in self._esctable:
raise ValueError(f"unknown markup: {name!r}")
raise ValueError(f'unknown markup: {name!r}')
if self.hasmarkup:
esc = [self._esctable[name] for name, on in markup.items() if on]
if esc:
text = "".join("\x1b[%sm" % cod for cod in esc) + text + "\x1b[0m"
text = ''.join('\x1b[%sm' % cod for cod in esc) + text + '\x1b[0m'
return text
def sep(
self,
sepchar: str,
title: Optional[str] = None,
fullwidth: Optional[int] = None,
title: str | None = None,
fullwidth: int | None = None,
**markup: bool,
) -> None:
if fullwidth is None:
fullwidth = self.fullwidth
# The goal is to have the line be as long as possible
# under the condition that len(line) <= fullwidth.
if sys.platform == "win32":
if sys.platform == 'win32':
# If we print in the last column on windows we are on a
# new line but there is no way to verify/neutralize this
# (we may not know the exact line width).
@ -130,7 +131,7 @@ class TerminalWriter:
# N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
fill = sepchar * N
line = f"{fill} {title} {fill}"
line = f'{fill} {title} {fill}'
else:
# we want len(sepchar)*N <= fullwidth
# i.e. N <= fullwidth // len(sepchar)
@ -145,8 +146,8 @@ class TerminalWriter:
def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
if msg:
current_line = msg.rsplit("\n", 1)[-1]
if "\n" in msg:
current_line = msg.rsplit('\n', 1)[-1]
if '\n' in msg:
self._current_line = current_line
else:
self._current_line += current_line
@ -162,15 +163,15 @@ class TerminalWriter:
# When the Unicode situation improves we should consider
# letting the error propagate instead of masking it (see #7475
# for one brief attempt).
msg = msg.encode("unicode-escape").decode("ascii")
msg = msg.encode('unicode-escape').decode('ascii')
self._file.write(msg)
if flush:
self.flush()
def line(self, s: str = "", **markup: bool) -> None:
def line(self, s: str = '', **markup: bool) -> None:
self.write(s, **markup)
self.write("\n")
self.write('\n')
def flush(self) -> None:
self._file.flush()
@ -184,17 +185,17 @@ class TerminalWriter:
"""
if indents and len(indents) != len(lines):
raise ValueError(
f"indents size ({len(indents)}) should have same size as lines ({len(lines)})"
f'indents size ({len(indents)}) should have same size as lines ({len(lines)})',
)
if not indents:
indents = [""] * len(lines)
source = "\n".join(lines)
indents = [''] * len(lines)
source = '\n'.join(lines)
new_lines = self._highlight(source).splitlines()
for indent, new_line in zip(indents, new_lines):
self.line(indent + new_line)
def _highlight(
self, source: str, lexer: Literal["diff", "python"] = "python"
self, source: str, lexer: Literal['diff', 'python'] = 'python',
) -> str:
"""Highlight the given source if we have markup support."""
from _pytest.config.exceptions import UsageError
@ -205,9 +206,9 @@ class TerminalWriter:
try:
from pygments.formatters.terminal import TerminalFormatter
if lexer == "python":
if lexer == 'python':
from pygments.lexers.python import PythonLexer as Lexer
elif lexer == "diff":
elif lexer == 'diff':
from pygments.lexers.diff import DiffLexer as Lexer
from pygments import highlight
import pygments.util
@ -219,30 +220,30 @@ class TerminalWriter:
source,
Lexer(),
TerminalFormatter(
bg=os.getenv("PYTEST_THEME_MODE", "dark"),
style=os.getenv("PYTEST_THEME"),
bg=os.getenv('PYTEST_THEME_MODE', 'dark'),
style=os.getenv('PYTEST_THEME'),
),
)
# pygments terminal formatter may add a newline when there wasn't one.
# We don't want this, remove.
if highlighted[-1] == "\n" and source[-1] != "\n":
if highlighted[-1] == '\n' and source[-1] != '\n':
highlighted = highlighted[:-1]
# Some lexers will not set the initial color explicitly
# which may lead to the previous color being propagated to the
# start of the expression, so reset first.
return "\x1b[0m" + highlighted
return '\x1b[0m' + highlighted
except pygments.util.ClassNotFound as e:
raise UsageError(
"PYTEST_THEME environment variable had an invalid value: '{}'. "
"Only valid pygment styles are allowed.".format(
os.getenv("PYTEST_THEME")
)
'Only valid pygment styles are allowed.'.format(
os.getenv('PYTEST_THEME'),
),
) from e
except pygments.util.OptionError as e:
raise UsageError(
"PYTEST_THEME_MODE environment variable had an invalid value: '{}'. "
"The only allowed values are 'dark' and 'light'.".format(
os.getenv("PYTEST_THEME_MODE")
)
os.getenv('PYTEST_THEME_MODE'),
),
) from e

View file

@ -1,5 +1,7 @@
from functools import lru_cache
from __future__ import annotations
import unicodedata
from functools import lru_cache
@lru_cache(100)
@ -17,25 +19,25 @@ def wcwidth(c: str) -> int:
# Some Cf/Zp/Zl characters which should be zero-width.
if (
o == 0x0000
or 0x200B <= o <= 0x200F
or 0x2028 <= o <= 0x202E
or 0x2060 <= o <= 0x2063
o == 0x0000 or
0x200B <= o <= 0x200F or
0x2028 <= o <= 0x202E or
0x2060 <= o <= 0x2063
):
return 0
category = unicodedata.category(c)
# Control characters.
if category == "Cc":
if category == 'Cc':
return -1
# Combining characters with zero width.
if category in ("Me", "Mn"):
if category in ('Me', 'Mn'):
return 0
# Full/Wide east asian characters.
if unicodedata.east_asian_width(c) in ("F", "W"):
if unicodedata.east_asian_width(c) in ('F', 'W'):
return 2
return 1
@ -47,7 +49,7 @@ def wcswidth(s: str) -> int:
Returns -1 if the string contains non-printable characters.
"""
width = 0
for c in unicodedata.normalize("NFC", s):
for c in unicodedata.normalize('NFC', s):
wc = wcwidth(c)
if wc < 0:
return -1

View file

@ -1,5 +1,4 @@
"""create errno-specific classes for IO or os calls."""
from __future__ import annotations
import errno
@ -13,25 +12,25 @@ from typing import TypeVar
if TYPE_CHECKING:
from typing_extensions import ParamSpec
P = ParamSpec("P")
P = ParamSpec('P')
R = TypeVar("R")
R = TypeVar('R')
class Error(EnvironmentError):
def __repr__(self) -> str:
return "{}.{} {!r}: {} ".format(
return '{}.{} {!r}: {} '.format(
self.__class__.__module__,
self.__class__.__name__,
self.__class__.__doc__,
" ".join(map(str, self.args)),
' '.join(map(str, self.args)),
# repr(self.args)
)
def __str__(self) -> str:
s = "[{}]: {}".format(
s = '[{}]: {}'.format(
self.__class__.__doc__,
" ".join(map(str, self.args)),
' '.join(map(str, self.args)),
)
return s
@ -58,7 +57,7 @@ class ErrorMaker:
_errno2class: dict[int, type[Error]] = {}
def __getattr__(self, name: str) -> type[Error]:
if name[0] == "_":
if name[0] == '_':
raise AttributeError(name)
eno = getattr(errno, name)
cls = self._geterrnoclass(eno)
@ -69,17 +68,17 @@ class ErrorMaker:
try:
return self._errno2class[eno]
except KeyError:
clsname = errno.errorcode.get(eno, "UnknownErrno%d" % (eno,))
clsname = errno.errorcode.get(eno, 'UnknownErrno%d' % (eno,))
errorcls = type(
clsname,
(Error,),
{"__module__": "py.error", "__doc__": os.strerror(eno)},
{'__module__': 'py.error', '__doc__': os.strerror(eno)},
)
self._errno2class[eno] = errorcls
return errorcls
def checked_call(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs,
) -> R:
"""Call a function and raise an errno-exception if applicable."""
__tracebackhide__ = True
@ -88,10 +87,10 @@ class ErrorMaker:
except Error:
raise
except OSError as value:
if not hasattr(value, "errno"):
if not hasattr(value, 'errno'):
raise
errno = value.errno
if sys.platform == "win32":
if sys.platform == 'win32':
try:
cls = self._geterrnoclass(_winerrnomap[errno])
except KeyError:
@ -100,7 +99,7 @@ class ErrorMaker:
# we are not on Windows, or we got a proper OSError
cls = self._geterrnoclass(errno)
raise cls(f"{func.__name__}{args!r}")
raise cls(f'{func.__name__}{args!r}')
_error_maker = ErrorMaker()

View file

@ -3,11 +3,15 @@
from __future__ import annotations
import atexit
from contextlib import contextmanager
import fnmatch
import importlib.util
import io
import os
import posixpath
import sys
import uuid
import warnings
from contextlib import contextmanager
from os.path import abspath
from os.path import dirname
from os.path import exists
@ -16,39 +20,35 @@ from os.path import isdir
from os.path import isfile
from os.path import islink
from os.path import normpath
import posixpath
from stat import S_ISDIR
from stat import S_ISLNK
from stat import S_ISREG
import sys
from typing import Any
from typing import Callable
from typing import cast
from typing import Literal
from typing import overload
from typing import TYPE_CHECKING
import uuid
import warnings
from . import error
# Moved from local.py.
iswin32 = sys.platform == "win32" or (getattr(os, "_name", False) == "nt")
iswin32 = sys.platform == 'win32' or (getattr(os, '_name', False) == 'nt')
class Checkers:
_depend_on_existence = "exists", "link", "dir", "file"
_depend_on_existence = 'exists', 'link', 'dir', 'file'
def __init__(self, path):
self.path = path
def dotfile(self):
return self.path.basename.startswith(".")
return self.path.basename.startswith('.')
def ext(self, arg):
if not arg.startswith("."):
arg = "." + arg
if not arg.startswith('.'):
arg = '.' + arg
return self.path.ext == arg
def basename(self, arg):
@ -75,14 +75,14 @@ class Checkers:
try:
meth = getattr(self, name)
except AttributeError:
if name[:3] == "not":
if name[:3] == 'not':
invert = True
try:
meth = getattr(self, name[3:])
except AttributeError:
pass
if meth is None:
raise TypeError(f"no {name!r} checker available for {self.path!r}")
raise TypeError(f'no {name!r} checker available for {self.path!r}')
try:
if getrawcode(meth).co_argcount > 1:
if (not meth(value)) ^ invert:
@ -98,7 +98,7 @@ class Checkers:
if name in kw:
if kw.get(name):
return False
name = "not" + name
name = 'not' + name
if name in kw:
if not kw.get(name):
return False
@ -140,7 +140,7 @@ class Visitor:
fil = FNMatcher(fil)
if isinstance(rec, str):
self.rec: Callable[[LocalPath], bool] = FNMatcher(rec)
elif not hasattr(rec, "__call__") and rec:
elif not hasattr(rec, '__call__') and rec:
self.rec = lambda path: True
else:
self.rec = rec
@ -156,7 +156,7 @@ class Visitor:
return
rec = self.rec
dirs = self.optsort(
[p for p in entries if p.check(dir=1) and (rec is None or rec(p))]
[p for p in entries if p.check(dir=1) and (rec is None or rec(p))],
)
if not self.breadthfirst:
for subdir in dirs:
@ -179,9 +179,9 @@ class FNMatcher:
pattern = self.pattern
if (
pattern.find(path.sep) == -1
and iswin32
and pattern.find(posixpath.sep) != -1
pattern.find(path.sep) == -1 and
iswin32 and
pattern.find(posixpath.sep) != -1
):
# Running on Windows, the pattern has no Windows path separators,
# and the pattern has one or more Posix path separators. Replace
@ -193,7 +193,7 @@ class FNMatcher:
else:
name = str(path) # path.strpath # XXX svn?
if not os.path.isabs(pattern):
pattern = "*" + path.sep + pattern
pattern = '*' + path.sep + pattern
return fnmatch.fnmatch(name, pattern)
@ -213,7 +213,7 @@ class Stat:
...
def __getattr__(self, name: str) -> Any:
return getattr(self._osstatresult, "st_" + name)
return getattr(self._osstatresult, 'st_' + name)
def __init__(self, path, osstatresult):
self.path = path
@ -222,7 +222,7 @@ class Stat:
@property
def owner(self):
if iswin32:
raise NotImplementedError("XXX win32")
raise NotImplementedError('XXX win32')
import pwd
entry = error.checked_call(pwd.getpwuid, self.uid) # type:ignore[attr-defined]
@ -232,7 +232,7 @@ class Stat:
def group(self):
"""Return group name of file."""
if iswin32:
raise NotImplementedError("XXX win32")
raise NotImplementedError('XXX win32')
import grp
entry = error.checked_call(grp.getgrgid, self.gid) # type:ignore[attr-defined]
@ -292,14 +292,14 @@ class LocalPath:
path = os.fspath(path)
except TypeError:
raise ValueError(
"can only pass None, Path instances "
"or non-empty strings to LocalPath"
'can only pass None, Path instances '
'or non-empty strings to LocalPath',
)
if expanduser:
path = os.path.expanduser(path)
self.strpath = abspath(path)
if sys.platform != "win32":
if sys.platform != 'win32':
def chown(self, user, group, rec=0):
"""Change ownership to the given user and group.
@ -334,7 +334,7 @@ class LocalPath:
relsource = self.__class__(value).relto(base)
reldest = self.relto(base)
n = reldest.count(self.sep)
target = self.sep.join(("..",) * n + (relsource,))
target = self.sep.join(('..',) * n + (relsource,))
error.checked_call(os.symlink, target, self.strpath)
def __div__(self, other):
@ -345,34 +345,34 @@ class LocalPath:
@property
def basename(self):
"""Basename part of path."""
return self._getbyspec("basename")[0]
return self._getbyspec('basename')[0]
@property
def dirname(self):
"""Dirname part of path."""
return self._getbyspec("dirname")[0]
return self._getbyspec('dirname')[0]
@property
def purebasename(self):
"""Pure base name of the path."""
return self._getbyspec("purebasename")[0]
return self._getbyspec('purebasename')[0]
@property
def ext(self):
"""Extension of the path (including the '.')."""
return self._getbyspec("ext")[0]
return self._getbyspec('ext')[0]
def read_binary(self):
"""Read and return a bytestring from reading the path."""
with self.open("rb") as f:
with self.open('rb') as f:
return f.read()
def read_text(self, encoding):
"""Read and return a Unicode string from reading the path."""
with self.open("r", encoding=encoding) as f:
with self.open('r', encoding=encoding) as f:
return f.read()
def read(self, mode="r"):
def read(self, mode='r'):
"""Read and return a bytestring from reading the path."""
with self.open(mode) as f:
return f.read()
@ -380,11 +380,11 @@ class LocalPath:
def readlines(self, cr=1):
"""Read and return a list of lines from the path. if cr is False, the
newline will be removed from the end of each line."""
mode = "r"
mode = 'r'
if not cr:
content = self.read(mode)
return content.split("\n")
return content.split('\n')
else:
f = self.open(mode)
try:
@ -394,7 +394,7 @@ class LocalPath:
def load(self):
"""(deprecated) return object unpickled from self.read()"""
f = self.open("rb")
f = self.open('rb')
try:
import pickle
@ -405,7 +405,7 @@ class LocalPath:
def move(self, target):
"""Move this path to target."""
if target.relto(self):
raise error.EINVAL(target, "cannot move path into a subdirectory of itself")
raise error.EINVAL(target, 'cannot move path into a subdirectory of itself')
try:
self.rename(target)
except error.EXDEV: # invalid cross-device link
@ -436,19 +436,19 @@ class LocalPath:
to the given 'relpath'.
"""
if not isinstance(relpath, (str, LocalPath)):
raise TypeError(f"{relpath!r}: not a string or path object")
raise TypeError(f'{relpath!r}: not a string or path object')
strrelpath = str(relpath)
if strrelpath and strrelpath[-1] != self.sep:
strrelpath += self.sep
# assert strrelpath[-1] == self.sep
# assert strrelpath[-2] != self.sep
strself = self.strpath
if sys.platform == "win32" or getattr(os, "_name", None) == "nt":
if sys.platform == 'win32' or getattr(os, '_name', None) == 'nt':
if os.path.normcase(strself).startswith(os.path.normcase(strrelpath)):
return strself[len(strrelpath) :]
return strself[len(strrelpath):]
elif strself.startswith(strrelpath):
return strself[len(strrelpath) :]
return ""
return strself[len(strrelpath):]
return ''
def ensure_dir(self, *args):
"""Ensure the path joined with args is a directory."""
@ -542,10 +542,10 @@ class LocalPath:
def _sortlist(self, res, sort):
if sort:
if hasattr(sort, "__call__"):
if hasattr(sort, '__call__'):
warnings.warn(
DeprecationWarning(
"listdir(sort=callable) is deprecated and breaks on python3"
'listdir(sort=callable) is deprecated and breaks on python3',
),
stacklevel=3,
)
@ -592,7 +592,7 @@ class LocalPath:
other = abspath(other)
if self == other:
return True
if not hasattr(os.path, "samefile"):
if not hasattr(os.path, 'samefile'):
return False
return error.checked_call(os.path.samefile, self.strpath, other)
@ -609,7 +609,7 @@ class LocalPath:
import shutil
error.checked_call(
shutil.rmtree, self.strpath, ignore_errors=ignore_errors
shutil.rmtree, self.strpath, ignore_errors=ignore_errors,
)
else:
error.checked_call(os.rmdir, self.strpath)
@ -618,19 +618,19 @@ class LocalPath:
self.chmod(0o700)
error.checked_call(os.remove, self.strpath)
def computehash(self, hashtype="md5", chunksize=524288):
def computehash(self, hashtype='md5', chunksize=524288):
"""Return hexdigest of hashvalue for this file."""
try:
try:
import hashlib as mod
except ImportError:
if hashtype == "sha1":
hashtype = "sha"
if hashtype == 'sha1':
hashtype = 'sha'
mod = __import__(hashtype)
hash = getattr(mod, hashtype)()
except (AttributeError, ImportError):
raise ValueError(f"Don't know how to compute {hashtype!r} hash")
f = self.open("rb")
f = self.open('rb')
try:
while 1:
buf = f.read(chunksize)
@ -656,28 +656,28 @@ class LocalPath:
obj.strpath = self.strpath
return obj
drive, dirname, basename, purebasename, ext = self._getbyspec(
"drive,dirname,basename,purebasename,ext"
'drive,dirname,basename,purebasename,ext',
)
if "basename" in kw:
if "purebasename" in kw or "ext" in kw:
raise ValueError("invalid specification %r" % kw)
if 'basename' in kw:
if 'purebasename' in kw or 'ext' in kw:
raise ValueError('invalid specification %r' % kw)
else:
pb = kw.setdefault("purebasename", purebasename)
pb = kw.setdefault('purebasename', purebasename)
try:
ext = kw["ext"]
ext = kw['ext']
except KeyError:
pass
else:
if ext and not ext.startswith("."):
ext = "." + ext
kw["basename"] = pb + ext
if ext and not ext.startswith('.'):
ext = '.' + ext
kw['basename'] = pb + ext
if "dirname" in kw and not kw["dirname"]:
kw["dirname"] = drive
if 'dirname' in kw and not kw['dirname']:
kw['dirname'] = drive
else:
kw.setdefault("dirname", dirname)
kw.setdefault("sep", self.sep)
obj.strpath = normpath("{dirname}{sep}{basename}".format(**kw))
kw.setdefault('dirname', dirname)
kw.setdefault('sep', self.sep)
obj.strpath = normpath('{dirname}{sep}{basename}'.format(**kw))
return obj
def _getbyspec(self, spec: str) -> list[str]:
@ -685,28 +685,28 @@ class LocalPath:
res = []
parts = self.strpath.split(self.sep)
args = filter(None, spec.split(","))
args = filter(None, spec.split(','))
for name in args:
if name == "drive":
if name == 'drive':
res.append(parts[0])
elif name == "dirname":
elif name == 'dirname':
res.append(self.sep.join(parts[:-1]))
else:
basename = parts[-1]
if name == "basename":
if name == 'basename':
res.append(basename)
else:
i = basename.rfind(".")
i = basename.rfind('.')
if i == -1:
purebasename, ext = basename, ""
purebasename, ext = basename, ''
else:
purebasename, ext = basename[:i], basename[i:]
if name == "purebasename":
if name == 'purebasename':
res.append(purebasename)
elif name == "ext":
elif name == 'ext':
res.append(ext)
else:
raise ValueError("invalid part specification %r" % name)
raise ValueError('invalid part specification %r' % name)
return res
def dirpath(self, *args, **kwargs):
@ -717,7 +717,7 @@ class LocalPath:
if args:
path = path.join(*args)
return path
return self.new(basename="").join(*args, **kwargs)
return self.new(basename='').join(*args, **kwargs)
def join(self, *args: os.PathLike[str], abs: bool = False) -> LocalPath:
"""Return a new path by appending all 'args' as path
@ -736,20 +736,20 @@ class LocalPath:
break
newargs.insert(0, arg)
# special case for when we have e.g. strpath == "/"
actual_sep = "" if strpath.endswith(sep) else sep
actual_sep = '' if strpath.endswith(sep) else sep
for arg in strargs:
arg = arg.strip(sep)
if iswin32:
# allow unix style paths even on windows.
arg = arg.strip("/")
arg = arg.replace("/", sep)
arg = arg.strip('/')
arg = arg.replace('/', sep)
strpath = strpath + actual_sep + arg
actual_sep = sep
obj = object.__new__(self.__class__)
obj.strpath = normpath(strpath)
return obj
def open(self, mode="r", ensure=False, encoding=None):
def open(self, mode='r', ensure=False, encoding=None):
"""Return an opened file with the given mode.
If ensure is True, create parent directories if needed.
@ -797,15 +797,15 @@ class LocalPath:
if not kw:
return exists(self.strpath)
if len(kw) == 1:
if "dir" in kw:
return not kw["dir"] ^ isdir(self.strpath)
if "file" in kw:
return not kw["file"] ^ isfile(self.strpath)
if 'dir' in kw:
return not kw['dir'] ^ isdir(self.strpath)
if 'file' in kw:
return not kw['file'] ^ isfile(self.strpath)
if not kw:
kw = {"exists": 1}
kw = {'exists': 1}
return Checkers(self)._evaluate(kw)
_patternchars = set("*?[" + os.sep)
_patternchars = set('*?[' + os.sep)
def listdir(self, fil=None, sort=None):
"""List directory contents, possibly filter by the given fil func
@ -882,7 +882,7 @@ class LocalPath:
def dump(self, obj, bin=1):
"""Pickle object into path location"""
f = self.open("wb")
f = self.open('wb')
import pickle
try:
@ -902,7 +902,7 @@ class LocalPath:
"""
if ensure:
self.dirpath().ensure(dir=1)
with self.open("wb") as f:
with self.open('wb') as f:
f.write(data)
def write_text(self, data, encoding, ensure=False):
@ -911,18 +911,18 @@ class LocalPath:
"""
if ensure:
self.dirpath().ensure(dir=1)
with self.open("w", encoding=encoding) as f:
with self.open('w', encoding=encoding) as f:
f.write(data)
def write(self, data, mode="w", ensure=False):
def write(self, data, mode='w', ensure=False):
"""Write data into path. If ensure is True create
missing parent directories.
"""
if ensure:
self.dirpath().ensure(dir=1)
if "b" in mode:
if 'b' in mode:
if not isinstance(data, bytes):
raise ValueError("can only process bytes")
raise ValueError('can only process bytes')
else:
if not isinstance(data, str):
if not isinstance(data, bytes):
@ -957,12 +957,12 @@ class LocalPath:
then the path is forced to be a directory path.
"""
p = self.join(*args)
if kwargs.get("dir", 0):
if kwargs.get('dir', 0):
return p._ensuredirs()
else:
p.dirpath()._ensuredirs()
if not p.check(file=1):
p.open("wb").close()
p.open('wb').close()
return p
@overload
@ -1033,7 +1033,7 @@ class LocalPath:
return self.stat().atime
def __repr__(self):
return "local(%r)" % self.strpath
return 'local(%r)' % self.strpath
def __str__(self):
"""Return string representation of the Path."""
@ -1045,7 +1045,7 @@ class LocalPath:
if rec is True perform recursively.
"""
if not isinstance(mode, int):
raise TypeError(f"mode {mode!r} must be an integer")
raise TypeError(f'mode {mode!r} must be an integer')
if rec:
for x in self.visit(rec=rec):
error.checked_call(os.chmod, str(x), mode)
@ -1059,7 +1059,7 @@ class LocalPath:
pkgpath = None
for parent in self.parts(reverse=True):
if parent.isdir():
if not parent.join("__init__.py").exists():
if not parent.join('__init__.py').exists():
break
if not isimportable(parent.basename):
break
@ -1069,7 +1069,7 @@ class LocalPath:
def _ensuresyspath(self, ensuremode, path):
if ensuremode:
s = str(path)
if ensuremode == "append":
if ensuremode == 'append':
if s not in sys.path:
sys.path.append(s)
else:
@ -1100,7 +1100,7 @@ class LocalPath:
if not self.check():
raise error.ENOENT(self)
if ensuresyspath == "importlib":
if ensuresyspath == 'importlib':
if modname is None:
modname = self.purebasename
spec = importlib.util.spec_from_file_location(modname, str(self))
@ -1115,10 +1115,10 @@ class LocalPath:
pkgpath = self.pypkgpath()
if pkgpath is not None:
pkgroot = pkgpath.dirpath()
names = self.new(ext="").relto(pkgroot).split(self.sep)
if names[-1] == "__init__":
names = self.new(ext='').relto(pkgroot).split(self.sep)
if names[-1] == '__init__':
names.pop()
modname = ".".join(names)
modname = '.'.join(names)
else:
pkgroot = self.dirpath()
modname = self.purebasename
@ -1126,25 +1126,25 @@ class LocalPath:
self._ensuresyspath(ensuresyspath, pkgroot)
__import__(modname)
mod = sys.modules[modname]
if self.basename == "__init__.py":
if self.basename == '__init__.py':
return mod # we don't check anything as we might
# be in a namespace package ... too icky to check
modfile = mod.__file__
assert modfile is not None
if modfile[-4:] in (".pyc", ".pyo"):
if modfile[-4:] in ('.pyc', '.pyo'):
modfile = modfile[:-1]
elif modfile.endswith("$py.class"):
modfile = modfile[:-9] + ".py"
if modfile.endswith(os.sep + "__init__.py"):
if self.basename != "__init__.py":
elif modfile.endswith('$py.class'):
modfile = modfile[:-9] + '.py'
if modfile.endswith(os.sep + '__init__.py'):
if self.basename != '__init__.py':
modfile = modfile[:-12]
try:
issame = self.samefile(modfile)
except error.ENOENT:
issame = False
if not issame:
ignore = os.getenv("PY_IGNORE_IMPORTMISMATCH")
if ignore != "1":
ignore = os.getenv('PY_IGNORE_IMPORTMISMATCH')
if ignore != '1':
raise self.ImportMismatchError(modname, modfile, self)
return mod
else:
@ -1158,7 +1158,7 @@ class LocalPath:
mod.__file__ = str(self)
sys.modules[modname] = mod
try:
with open(str(self), "rb") as f:
with open(str(self), 'rb') as f:
exec(f.read(), mod.__dict__)
except BaseException:
del sys.modules[modname]
@ -1173,8 +1173,8 @@ class LocalPath:
from subprocess import PIPE
from subprocess import Popen
popen_opts.pop("stdout", None)
popen_opts.pop("stderr", None)
popen_opts.pop('stdout', None)
popen_opts.pop('stderr', None)
proc = Popen(
[str(self)] + [str(arg) for arg in argv],
**popen_opts,
@ -1214,23 +1214,23 @@ class LocalPath:
else:
if paths is None:
if iswin32:
paths = os.environ["Path"].split(";")
if "" not in paths and "." not in paths:
paths.append(".")
paths = os.environ['Path'].split(';')
if '' not in paths and '.' not in paths:
paths.append('.')
try:
systemroot = os.environ["SYSTEMROOT"]
systemroot = os.environ['SYSTEMROOT']
except KeyError:
pass
else:
paths = [
path.replace("%SystemRoot%", systemroot) for path in paths
path.replace('%SystemRoot%', systemroot) for path in paths
]
else:
paths = os.environ["PATH"].split(":")
paths = os.environ['PATH'].split(':')
tryadd = []
if iswin32:
tryadd += os.environ["PATHEXT"].split(os.pathsep)
tryadd.append("")
tryadd += os.environ['PATHEXT'].split(os.pathsep)
tryadd.append('')
for x in paths:
for addext in tryadd:
@ -1248,10 +1248,10 @@ class LocalPath:
@classmethod
def _gethomedir(cls):
try:
x = os.environ["HOME"]
x = os.environ['HOME']
except KeyError:
try:
x = os.environ["HOMEDRIVE"] + os.environ["HOMEPATH"]
x = os.environ['HOMEDRIVE'] + os.environ['HOMEPATH']
except KeyError:
return None
return cls(x)
@ -1288,7 +1288,7 @@ class LocalPath:
@classmethod
def make_numbered_dir(
cls, prefix="session-", rootdir=None, keep=3, lock_timeout=172800
cls, prefix='session-', rootdir=None, keep=3, lock_timeout=172800,
): # two days
"""Return unique directory with a number greater than the current
maximum one. The number is assumed to start directly after prefix.
@ -1306,21 +1306,21 @@ class LocalPath:
nbasename = path.basename.lower()
if nbasename.startswith(nprefix):
try:
return int(nbasename[len(nprefix) :])
return int(nbasename[len(nprefix):])
except ValueError:
pass
def create_lockfile(path):
"""Exclusively create lockfile. Throws when failed"""
mypid = os.getpid()
lockfile = path.join(".lock")
if hasattr(lockfile, "mksymlinkto"):
lockfile = path.join('.lock')
if hasattr(lockfile, 'mksymlinkto'):
lockfile.mksymlinkto(str(mypid))
else:
fd = error.checked_call(
os.open, str(lockfile), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644
os.open, str(lockfile), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644,
)
with os.fdopen(fd, "w") as f:
with os.fdopen(fd, 'w') as f:
f.write(str(mypid))
return lockfile
@ -1380,7 +1380,7 @@ class LocalPath:
except error.Error:
pass
garbage_prefix = prefix + "garbage-"
garbage_prefix = prefix + 'garbage-'
def is_garbage(path):
"""Check if path denotes directory scheduled for removal"""
@ -1428,15 +1428,15 @@ class LocalPath:
# make link...
try:
username = os.environ["USER"] # linux, et al
username = os.environ['USER'] # linux, et al
except KeyError:
try:
username = os.environ["USERNAME"] # windows
username = os.environ['USERNAME'] # windows
except KeyError:
username = "current"
username = 'current'
src = str(udir)
dest = src[: src.rfind("-")] + "-" + username
dest = src[: src.rfind('-')] + '-' + username
try:
os.unlink(dest)
except OSError:
@ -1466,9 +1466,9 @@ def copystat(src, dest):
def copychunked(src, dest):
chunksize = 524288 # half a meg of bytes
fsrc = src.open("rb")
fsrc = src.open('rb')
try:
fdest = dest.open("wb")
fdest = dest.open('wb')
try:
while 1:
buf = fsrc.read(chunksize)
@ -1482,8 +1482,8 @@ def copychunked(src, dest):
def isimportable(name):
if name and (name[0].isalpha() or name[0] == "_"):
name = name.replace("_", "")
if name and (name[0].isalpha() or name[0] == '_'):
name = name.replace('_', '')
return not name or name.isalnum()

View file

@ -1,5 +1,6 @@
# file generated by setuptools_scm
# don't change, don't track in version control
from __future__ import annotations
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple, Union

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
"""Support for presenting detailed information in failing assertions."""
from __future__ import annotations
import sys
from typing import Any
from typing import Generator
@ -22,34 +24,34 @@ if TYPE_CHECKING:
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group = parser.getgroup('debugconfig')
group.addoption(
"--assert",
action="store",
dest="assertmode",
choices=("rewrite", "plain"),
default="rewrite",
metavar="MODE",
'--assert',
action='store',
dest='assertmode',
choices=('rewrite', 'plain'),
default='rewrite',
metavar='MODE',
help=(
"Control assertion debugging tools.\n"
'Control assertion debugging tools.\n'
"'plain' performs no assertion debugging.\n"
"'rewrite' (the default) rewrites assert statements in test modules"
" on import to provide assert expression information."
' on import to provide assert expression information.'
),
)
parser.addini(
"enable_assertion_pass_hook",
type="bool",
'enable_assertion_pass_hook',
type='bool',
default=False,
help="Enables the pytest_assertion_pass hook. "
"Make sure to delete any previously generated pyc cache files.",
help='Enables the pytest_assertion_pass hook. '
'Make sure to delete any previously generated pyc cache files.',
)
Config._add_verbosity_ini(
parser,
Config.VERBOSITY_ASSERTIONS,
help=(
"Specify a verbosity level for assertions, overriding the main level. "
"Higher levels will provide more detailed explanation when an assertion fails."
'Specify a verbosity level for assertions, overriding the main level. '
'Higher levels will provide more detailed explanation when an assertion fails.'
),
)
@ -67,7 +69,7 @@ def register_assert_rewrite(*names: str) -> None:
"""
for name in names:
if not isinstance(name, str):
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
msg = 'expected module names as *args, got {0} instead' # type: ignore[unreachable]
raise TypeError(msg.format(repr(names)))
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
@ -92,16 +94,16 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook: Optional[rewrite.AssertionRewritingHook] = None
self.trace = config.trace.root.get('assertion')
self.hook: rewrite.AssertionRewritingHook | None = None
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config.stash[assertstate_key] = AssertionState(config, "rewrite")
config.stash[assertstate_key] = AssertionState(config, 'rewrite')
config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config.stash[assertstate_key].trace("installed rewrite import hook")
config.stash[assertstate_key].trace('installed rewrite import hook')
def undo() -> None:
hook = config.stash[assertstate_key].hook
@ -112,7 +114,7 @@ def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
return hook
def pytest_collection(session: "Session") -> None:
def pytest_collection(session: Session) -> None:
# This hook is only called when test modules are collected
# so for example not in the managing process of pytest-xdist
# (which does not collect test modules).
@ -132,7 +134,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
"""
ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> Optional[str]:
def callbinrepr(op, left: object, right: object) -> str | None:
"""Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the
@ -148,15 +150,15 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
pretty printing.
"""
hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right
config=item.config, op=op, left=left, right=right,
)
for new_expl in hook_result:
if new_expl:
new_expl = truncate.truncate_if_required(new_expl, item)
new_expl = [line.replace("\n", "\\n") for line in new_expl]
res = "\n~".join(new_expl)
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
new_expl = [line.replace('\n', '\\n') for line in new_expl]
res = '\n~'.join(new_expl)
if item.config.getvalue('assertmode') == 'rewrite':
res = res.replace('%', '%%')
return res
return None
@ -178,7 +180,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
util._config = None
def pytest_sessionfinish(session: "Session") -> None:
def pytest_sessionfinish(session: Session) -> None:
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
@ -186,6 +188,6 @@ def pytest_sessionfinish(session: "Session") -> None:
def pytest_assertrepr_compare(
config: Config, op: str, left: Any, right: Any
) -> Optional[List[str]]:
config: Config, op: str, left: Any, right: Any,
) -> list[str] | None:
return util.assertrepr_compare(config=config, op=op, left=left, right=right)

View file

@ -3,6 +3,7 @@
Current default behaviour is to truncate assertion explanations at
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
"""
from __future__ import annotations
from typing import List
from typing import Optional
@ -18,8 +19,8 @@ USAGE_MSG = "use '-vv' to show"
def truncate_if_required(
explanation: List[str], item: Item, max_length: Optional[int] = None
) -> List[str]:
explanation: list[str], item: Item, max_length: int | None = None,
) -> list[str]:
"""Truncate this assertion explanation if the given test item is eligible."""
if _should_truncate_item(item):
return _truncate_explanation(explanation)
@ -33,10 +34,10 @@ def _should_truncate_item(item: Item) -> bool:
def _truncate_explanation(
input_lines: List[str],
max_lines: Optional[int] = None,
max_chars: Optional[int] = None,
) -> List[str]:
input_lines: list[str],
max_lines: int | None = None,
max_chars: int | None = None,
) -> list[str]:
"""Truncate given list of strings that makes up the assertion explanation.
Truncates to either 8 lines, or 640 characters - whichever the input reaches
@ -49,7 +50,7 @@ def _truncate_explanation(
max_chars = DEFAULT_MAX_CHARS
# Check if truncation required
input_char_count = len("".join(input_lines))
input_char_count = len(''.join(input_lines))
# The length of the truncation explanation depends on the number of lines
# removed but is at least 68 characters:
# The real value is
@ -67,17 +68,17 @@ def _truncate_explanation(
# The truncation explanation add two lines to the output
tolerable_max_lines = max_lines + 2
if (
len(input_lines) <= tolerable_max_lines
and input_char_count <= tolerable_max_chars
len(input_lines) <= tolerable_max_lines and
input_char_count <= tolerable_max_chars
):
return input_lines
# Truncate first to max_lines, and then truncate to max_chars if necessary
truncated_explanation = input_lines[:max_lines]
truncated_char = True
# We reevaluate the need to truncate chars following removal of some lines
if len("".join(truncated_explanation)) > tolerable_max_chars:
if len(''.join(truncated_explanation)) > tolerable_max_chars:
truncated_explanation = _truncate_by_char_count(
truncated_explanation, max_chars
truncated_explanation, max_chars,
)
else:
truncated_char = False
@ -85,22 +86,22 @@ def _truncate_explanation(
truncated_line_count = len(input_lines) - len(truncated_explanation)
if truncated_explanation[-1]:
# Add ellipsis and take into account part-truncated final line
truncated_explanation[-1] = truncated_explanation[-1] + "..."
truncated_explanation[-1] = truncated_explanation[-1] + '...'
if truncated_char:
# It's possible that we did not remove any char from this line
truncated_line_count += 1
else:
# Add proper ellipsis when we were able to fit a full line exactly
truncated_explanation[-1] = "..."
truncated_explanation[-1] = '...'
return [
*truncated_explanation,
"",
f"...Full output truncated ({truncated_line_count} line"
'',
f'...Full output truncated ({truncated_line_count} line'
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
]
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
# Find point at which input length exceeds total allowed length
iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines):

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
"""Utilities for assertion debugging."""
from __future__ import annotations
import collections.abc
import os
import pprint
@ -15,8 +17,8 @@ from typing import Protocol
from typing import Sequence
from unicodedata import normalize
from _pytest import outcomes
import _pytest._code
from _pytest import outcomes
from _pytest._io.pprint import PrettyPrinter
from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited
@ -27,18 +29,18 @@ from _pytest.config import Config
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None
_reprcompare: Callable[[str, object, object], str | None] | None = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass: Optional[Callable[[int, str, str], None]] = None
_assertion_pass: Callable[[int, str, str], None] | None = None
# Config object which is assigned during pytest_runtest_protocol.
_config: Optional[Config] = None
_config: Config | None = None
class _HighlightFunc(Protocol):
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
def __call__(self, source: str, lexer: Literal['diff', 'python'] = 'python') -> str:
"""Apply highlighting to the given source."""
@ -54,27 +56,27 @@ def format_explanation(explanation: str) -> str:
"""
lines = _split_explanation(explanation)
result = _format_lines(lines)
return "\n".join(result)
return '\n'.join(result)
def _split_explanation(explanation: str) -> List[str]:
def _split_explanation(explanation: str) -> list[str]:
r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the
literal '\n' characters.
"""
raw_lines = (explanation or "").split("\n")
raw_lines = (explanation or '').split('\n')
lines = [raw_lines[0]]
for values in raw_lines[1:]:
if values and values[0] in ["{", "}", "~", ">"]:
if values and values[0] in ['{', '}', '~', '>']:
lines.append(values)
else:
lines[-1] += "\\n" + values
lines[-1] += '\\n' + values
return lines
def _format_lines(lines: Sequence[str]) -> List[str]:
def _format_lines(lines: Sequence[str]) -> list[str]:
"""Format the individual lines.
This will replace the '{', '}' and '~' characters of our mini formatting
@ -87,24 +89,24 @@ def _format_lines(lines: Sequence[str]) -> List[str]:
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith("{"):
if line.startswith('{'):
if stackcnt[-1]:
s = "and "
s = 'and '
else:
s = "where "
s = 'where '
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
elif line.startswith("}"):
result.append(' +' + ' ' * (len(stack) - 1) + s + line[1:])
elif line.startswith('}'):
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
else:
assert line[0] in ["~", ">"]
assert line[0] in ['~', '>']
stack[-1] += 1
indent = len(stack) if line.startswith("~") else len(stack) - 1
result.append(" " * indent + line[1:])
indent = len(stack) if line.startswith('~') else len(stack) - 1
result.append(' ' * indent + line[1:])
assert len(stack) == 1
return result
@ -126,15 +128,15 @@ def isset(x: Any) -> bool:
def isnamedtuple(obj: Any) -> bool:
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
return isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None
def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None
return getattr(obj, '__dataclass_fields__', None) is not None
def isattrs(obj: Any) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None
return getattr(obj, '__attrs_attrs__', None) is not None
def isiterable(obj: Any) -> bool:
@ -156,28 +158,28 @@ def has_default_eq(
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
"""
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"):
if hasattr(obj.__eq__, '__code__') and hasattr(obj.__eq__.__code__, 'co_filename'):
code_filename = obj.__eq__.__code__.co_filename
if isattrs(obj):
return "attrs generated eq" in code_filename
return 'attrs generated eq' in code_filename
return code_filename == "<string>" # data class
return code_filename == '<string>' # data class
return True
def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False
) -> Optional[List[str]]:
config, op: str, left: Any, right: Any, use_ascii: bool = False,
) -> list[str] | None:
"""Return specialised explanations for some operators/operands."""
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
# See issue #3246.
use_ascii = (
isinstance(left, str)
and isinstance(right, str)
and normalize("NFD", left) == normalize("NFD", right)
isinstance(left, str) and
isinstance(right, str) and
normalize('NFD', left) == normalize('NFD', right)
)
if verbose > 1:
@ -193,29 +195,29 @@ def assertrepr_compare(
left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii)
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
summary = f"{left_repr} {op} {right_repr}"
summary = f'{left_repr} {op} {right_repr}'
highlighter = config.get_terminal_writer()._highlight
explanation = None
try:
if op == "==":
if op == '==':
explanation = _compare_eq_any(left, right, highlighter, verbose)
elif op == "not in":
elif op == 'not in':
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
elif op == "!=":
elif op == '!=':
if isset(left) and isset(right):
explanation = ["Both sets are equal"]
elif op == ">=":
explanation = ['Both sets are equal']
elif op == '>=':
if isset(left) and isset(right):
explanation = _compare_gte_set(left, right, highlighter, verbose)
elif op == "<=":
elif op == '<=':
if isset(left) and isset(right):
explanation = _compare_lte_set(left, right, highlighter, verbose)
elif op == ">":
elif op == '>':
if isset(left) and isset(right):
explanation = _compare_gt_set(left, right, highlighter, verbose)
elif op == "<":
elif op == '<':
if isset(left) and isset(right):
explanation = _compare_lt_set(left, right, highlighter, verbose)
@ -223,23 +225,23 @@ def assertrepr_compare(
raise
except Exception:
explanation = [
"(pytest_assertion plugin: representation of details failed: {}.".format(
_pytest._code.ExceptionInfo.from_current()._getreprcrash()
'(pytest_assertion plugin: representation of details failed: {}.'.format(
_pytest._code.ExceptionInfo.from_current()._getreprcrash(),
),
" Probably an object has a faulty __repr__.)",
' Probably an object has a faulty __repr__.)',
]
if not explanation:
return None
if explanation[0] != "":
explanation = ["", *explanation]
if explanation[0] != '':
explanation = ['', *explanation]
return [summary, *explanation]
def _compare_eq_any(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
) -> List[str]:
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0,
) -> list[str]:
explanation = []
if istext(left) and istext(right):
explanation = _diff_text(left, right, verbose)
@ -274,7 +276,7 @@ def _compare_eq_any(
return explanation
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
def _diff_text(left: str, right: str, verbose: int = 0) -> list[str]:
"""Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing
@ -282,7 +284,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
"""
from difflib import ndiff
explanation: List[str] = []
explanation: list[str] = []
if verbose < 1:
i = 0 # just in case left or right has zero length
@ -292,7 +294,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if i > 42:
i -= 10 # Provide some context
explanation = [
"Skipping %s identical leading characters in diff, use -v to show" % i
'Skipping %s identical leading characters in diff, use -v to show' % i,
]
left = left[i:]
right = right[i:]
@ -303,8 +305,8 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if i > 42:
i -= 10 # Provide some context
explanation += [
f"Skipping {i} identical trailing "
"characters in diff, use -v to show"
f'Skipping {i} identical trailing '
'characters in diff, use -v to show',
]
left = left[:-i]
right = right[:-i]
@ -312,11 +314,11 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if left.isspace() or right.isspace():
left = repr(str(left))
right = repr(str(right))
explanation += ["Strings contain only whitespace, escaping them using repr()"]
explanation += ['Strings contain only whitespace, escaping them using repr()']
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation += [
line.strip("\n")
line.strip('\n')
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
]
return explanation
@ -327,26 +329,26 @@ def _compare_eq_iterable(
right: Iterable[Any],
highligher: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
) -> list[str]:
if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"]
return ['Use -v to get more diff']
# dynamic import to speedup pytest
import difflib
left_formatting = PrettyPrinter().pformat(left).splitlines()
right_formatting = PrettyPrinter().pformat(right).splitlines()
explanation = ["", "Full diff:"]
explanation = ['', 'Full diff:']
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
highligher(
"\n".join(
'\n'.join(
line.rstrip()
for line in difflib.ndiff(right_formatting, left_formatting)
),
lexer="diff",
).splitlines()
lexer='diff',
).splitlines(),
)
return explanation
@ -356,9 +358,9 @@ def _compare_eq_sequence(
right: Sequence[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
) -> list[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: List[str] = []
explanation: list[str] = []
len_left = len(left)
len_right = len(right)
for i in range(min(len_left, len_right)):
@ -372,15 +374,15 @@ def _compare_eq_sequence(
# 102
# >>> s[0:1]
# b'f'
left_value = left[i : i + 1]
right_value = right[i : i + 1]
left_value = left[i: i + 1]
right_value = right[i: i + 1]
else:
left_value = left[i]
right_value = right[i]
explanation.append(
f"At index {i} diff:"
f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}"
f'At index {i} diff:'
f' {highlighter(repr(left_value))} != {highlighter(repr(right_value))}',
)
break
@ -393,21 +395,21 @@ def _compare_eq_sequence(
len_diff = len_left - len_right
if len_diff:
if len_diff > 0:
dir_with_more = "Left"
dir_with_more = 'Left'
extra = saferepr(left[len_right])
else:
len_diff = 0 - len_diff
dir_with_more = "Right"
dir_with_more = 'Right'
extra = saferepr(right[len_left])
if len_diff == 1:
explanation += [
f"{dir_with_more} contains one more item: {highlighter(extra)}"
f'{dir_with_more} contains one more item: {highlighter(extra)}',
]
else:
explanation += [
"%s contains %d more items, first extra item: %s"
% (dir_with_more, len_diff, highlighter(extra))
'%s contains %d more items, first extra item: %s'
% (dir_with_more, len_diff, highlighter(extra)),
]
return explanation
@ -417,10 +419,10 @@ def _compare_eq_set(
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
) -> list[str]:
explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
explanation.extend(_set_one_sided_diff('left', left, right, highlighter))
explanation.extend(_set_one_sided_diff('right', right, left, highlighter))
return explanation
@ -429,10 +431,10 @@ def _compare_gt_set(
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return ['Both sets are equal']
return explanation
@ -441,10 +443,10 @@ def _compare_lt_set(
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return ['Both sets are equal']
return explanation
@ -453,8 +455,8 @@ def _compare_gte_set(
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
return _set_one_sided_diff("right", right, left, highlighter)
) -> list[str]:
return _set_one_sided_diff('right', right, left, highlighter)
def _compare_lte_set(
@ -462,8 +464,8 @@ def _compare_lte_set(
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
return _set_one_sided_diff("left", left, right, highlighter)
) -> list[str]:
return _set_one_sided_diff('left', left, right, highlighter)
def _set_one_sided_diff(
@ -471,11 +473,11 @@ def _set_one_sided_diff(
set1: AbstractSet[Any],
set2: AbstractSet[Any],
highlighter: _HighlightFunc,
) -> List[str]:
) -> list[str]:
explanation = []
diff = set1 - set2
if diff:
explanation.append(f"Extra items in the {posn} set:")
explanation.append(f'Extra items in the {posn} set:')
for item in diff:
explanation.append(highlighter(saferepr(item)))
return explanation
@ -486,52 +488,52 @@ def _compare_eq_dict(
right: Mapping[Any, Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> List[str]:
explanation: List[str] = []
) -> list[str]:
explanation: list[str] = []
set_left = set(left)
set_right = set(right)
common = set_left.intersection(set_right)
same = {k: left[k] for k in common if left[k] == right[k]}
if same and verbose < 2:
explanation += ["Omitting %s identical items, use -vv to show" % len(same)]
explanation += ['Omitting %s identical items, use -vv to show' % len(same)]
elif same:
explanation += ["Common items:"]
explanation += ['Common items:']
explanation += highlighter(pprint.pformat(same)).splitlines()
diff = {k for k in common if left[k] != right[k]}
if diff:
explanation += ["Differing items:"]
explanation += ['Differing items:']
for k in diff:
explanation += [
highlighter(saferepr({k: left[k]}))
+ " != "
+ highlighter(saferepr({k: right[k]}))
highlighter(saferepr({k: left[k]})) +
' != ' +
highlighter(saferepr({k: right[k]})),
]
extra_left = set_left - set_right
len_extra_left = len(extra_left)
if len_extra_left:
explanation.append(
"Left contains %d more item%s:"
% (len_extra_left, "" if len_extra_left == 1 else "s")
'Left contains %d more item%s:'
% (len_extra_left, '' if len_extra_left == 1 else 's'),
)
explanation.extend(
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines()
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines(),
)
extra_right = set_right - set_left
len_extra_right = len(extra_right)
if len_extra_right:
explanation.append(
"Right contains %d more item%s:"
% (len_extra_right, "" if len_extra_right == 1 else "s")
'Right contains %d more item%s:'
% (len_extra_right, '' if len_extra_right == 1 else 's'),
)
explanation.extend(
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines()
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines(),
)
return explanation
def _compare_eq_cls(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
) -> List[str]:
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int,
) -> list[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
@ -541,13 +543,13 @@ def _compare_eq_cls(
fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
fields_to_check = [field.name for field in all_fields if getattr(field, 'eq')]
elif isnamedtuple(left):
fields_to_check = left._fields
else:
assert False
indent = " "
indent = ' '
same = []
diff = []
for field in fields_to_check:
@ -558,46 +560,46 @@ def _compare_eq_cls(
explanation = []
if same or diff:
explanation += [""]
explanation += ['']
if same and verbose < 2:
explanation.append("Omitting %s identical items, use -vv to show" % len(same))
explanation.append('Omitting %s identical items, use -vv to show' % len(same))
elif same:
explanation += ["Matching attributes:"]
explanation += ['Matching attributes:']
explanation += highlighter(pprint.pformat(same)).splitlines()
if diff:
explanation += ["Differing attributes:"]
explanation += ['Differing attributes:']
explanation += highlighter(pprint.pformat(diff)).splitlines()
for field in diff:
field_left = getattr(left, field)
field_right = getattr(right, field)
explanation += [
"",
f"Drill down into differing attribute {field}:",
f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}",
'',
f'Drill down into differing attribute {field}:',
f'{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}',
]
explanation += [
indent + line
for line in _compare_eq_any(
field_left, field_right, highlighter, verbose
field_left, field_right, highlighter, verbose,
)
]
return explanation
def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
index = text.find(term)
head = text[:index]
tail = text[index + len(term) :]
tail = text[index + len(term):]
correct_text = head + tail
diff = _diff_text(text, correct_text, verbose)
newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)]
newdiff = ['%s is contained here:' % saferepr(term, maxsize=42)]
for line in diff:
if line.startswith("Skipping"):
if line.startswith('Skipping'):
continue
if line.startswith("- "):
if line.startswith('- '):
continue
if line.startswith("+ "):
newdiff.append(" " + line[2:])
if line.startswith('+ '):
newdiff.append(' ' + line[2:])
else:
newdiff.append(line)
return newdiff
@ -605,5 +607,5 @@ def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
def running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
env_vars = ['CI', 'BUILD_NUMBER']
return any(var in os.environ for var in env_vars)

View file

@ -2,6 +2,8 @@
"""Implementation of the cache provider."""
# This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version.
from __future__ import annotations
import dataclasses
import json
import os
@ -15,9 +17,6 @@ from typing import Optional
from typing import Set
from typing import Union
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.config import Config
@ -32,6 +31,10 @@ from _pytest.nodes import Directory
from _pytest.nodes import File
from _pytest.reports import TestReport
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
README_CONTENT = """\
# pytest cache directory #
@ -61,27 +64,27 @@ class Cache:
_config: Config = dataclasses.field(repr=False)
# Sub-directory under cache-dir for directories created by `mkdir()`.
_CACHE_PREFIX_DIRS = "d"
_CACHE_PREFIX_DIRS = 'd'
# Sub-directory under cache-dir for values created by `set()`.
_CACHE_PREFIX_VALUES = "v"
_CACHE_PREFIX_VALUES = 'v'
def __init__(
self, cachedir: Path, config: Config, *, _ispytest: bool = False
self, cachedir: Path, config: Config, *, _ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._cachedir = cachedir
self._config = config
@classmethod
def for_config(cls, config: Config, *, _ispytest: bool = False) -> "Cache":
def for_config(cls, config: Config, *, _ispytest: bool = False) -> Cache:
"""Create the Cache instance for a Config.
:meta private:
"""
check_ispytest(_ispytest)
cachedir = cls.cache_dir_from_config(config, _ispytest=True)
if config.getoption("cacheclear") and cachedir.is_dir():
if config.getoption('cacheclear') and cachedir.is_dir():
cls.clear_cache(cachedir, _ispytest=True)
return cls(cachedir, config, _ispytest=True)
@ -104,7 +107,7 @@ class Cache:
:meta private:
"""
check_ispytest(_ispytest)
return resolve_from_str(config.getini("cache_dir"), config.rootpath)
return resolve_from_str(config.getini('cache_dir'), config.rootpath)
def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
"""Issue a cache warning.
@ -138,7 +141,7 @@ class Cache:
"""
path = Path(name)
if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
raise ValueError('name is not allowed to contain path separators')
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
res.mkdir(exist_ok=True, parents=True)
return res
@ -160,7 +163,7 @@ class Cache:
"""
path = self._getvaluepath(key)
try:
with path.open("r", encoding="UTF-8") as f:
with path.open('r', encoding='UTF-8') as f:
return json.load(f)
except (ValueError, OSError):
return default
@ -184,7 +187,7 @@ class Cache:
path.parent.mkdir(exist_ok=True, parents=True)
except OSError as exc:
self.warn(
f"could not create cache path {path}: {exc}",
f'could not create cache path {path}: {exc}',
_ispytest=True,
)
return
@ -192,10 +195,10 @@ class Cache:
self._ensure_supporting_files()
data = json.dumps(value, ensure_ascii=False, indent=2)
try:
f = path.open("w", encoding="UTF-8")
f = path.open('w', encoding='UTF-8')
except OSError as exc:
self.warn(
f"cache could not write path {path}: {exc}",
f'cache could not write path {path}: {exc}',
_ispytest=True,
)
else:
@ -204,25 +207,25 @@ class Cache:
def _ensure_supporting_files(self) -> None:
"""Create supporting files in the cache dir that are not really part of the cache."""
readme_path = self._cachedir / "README.md"
readme_path.write_text(README_CONTENT, encoding="UTF-8")
readme_path = self._cachedir / 'README.md'
readme_path.write_text(README_CONTENT, encoding='UTF-8')
gitignore_path = self._cachedir.joinpath(".gitignore")
msg = "# Created by pytest automatically.\n*\n"
gitignore_path.write_text(msg, encoding="UTF-8")
gitignore_path = self._cachedir.joinpath('.gitignore')
msg = '# Created by pytest automatically.\n*\n'
gitignore_path.write_text(msg, encoding='UTF-8')
cachedir_tag_path = self._cachedir.joinpath("CACHEDIR.TAG")
cachedir_tag_path = self._cachedir.joinpath('CACHEDIR.TAG')
cachedir_tag_path.write_bytes(CACHEDIR_TAG_CONTENT)
class LFPluginCollWrapper:
def __init__(self, lfplugin: "LFPlugin") -> None:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
self._collected_at_least_one_failure = False
@hookimpl(wrapper=True)
def pytest_make_collect_report(
self, collector: nodes.Collector
self, collector: nodes.Collector,
) -> Generator[None, CollectReport, CollectReport]:
res = yield
if isinstance(collector, (Session, Directory)):
@ -230,7 +233,7 @@ class LFPluginCollWrapper:
lf_paths = self.lfplugin._last_failed_paths
# Use stable sort to priorize last failed.
def sort_key(node: Union[nodes.Item, nodes.Collector]) -> bool:
def sort_key(node: nodes.Item | nodes.Collector) -> bool:
return node.path in lf_paths
res.result = sorted(
@ -249,7 +252,7 @@ class LFPluginCollWrapper:
if not any(x.nodeid in lastfailed for x in result):
return res
self.lfplugin.config.pluginmanager.register(
LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
LFPluginCollSkipfiles(self.lfplugin), 'lfplugin-collskip',
)
self._collected_at_least_one_failure = True
@ -257,30 +260,30 @@ class LFPluginCollWrapper:
result[:] = [
x
for x in result
if x.nodeid in lastfailed
if x.nodeid in lastfailed or
# Include any passed arguments (not trivial to filter).
or session.isinitpath(x.path)
session.isinitpath(x.path) or
# Keep all sub-collectors.
or isinstance(x, nodes.Collector)
isinstance(x, nodes.Collector)
]
return res
class LFPluginCollSkipfiles:
def __init__(self, lfplugin: "LFPlugin") -> None:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
@hookimpl
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Optional[CollectReport]:
self, collector: nodes.Collector,
) -> CollectReport | None:
if isinstance(collector, File):
if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
return CollectReport(
collector.nodeid, "passed", longrepr=None, result=[]
collector.nodeid, 'passed', longrepr=None, result=[],
)
return None
@ -290,44 +293,44 @@ class LFPlugin:
def __init__(self, config: Config) -> None:
self.config = config
active_keys = "lf", "failedfirst"
active_keys = 'lf', 'failedfirst'
self.active = any(config.getoption(key) for key in active_keys)
assert config.cache
self.lastfailed: Dict[str, bool] = config.cache.get("cache/lastfailed", {})
self._previously_failed_count: Optional[int] = None
self._report_status: Optional[str] = None
self.lastfailed: dict[str, bool] = config.cache.get('cache/lastfailed', {})
self._previously_failed_count: int | None = None
self._report_status: str | None = None
self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"):
if config.getoption('lf'):
self._last_failed_paths = self.get_last_failed_paths()
config.pluginmanager.register(
LFPluginCollWrapper(self), "lfplugin-collwrapper"
LFPluginCollWrapper(self), 'lfplugin-collwrapper',
)
def get_last_failed_paths(self) -> Set[Path]:
def get_last_failed_paths(self) -> set[Path]:
"""Return a set with all Paths of the previously failed nodeids and
their parents."""
rootpath = self.config.rootpath
result = set()
for nodeid in self.lastfailed:
path = rootpath / nodeid.split("::")[0]
path = rootpath / nodeid.split('::')[0]
result.add(path)
result.update(path.parents)
return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> Optional[str]:
if self.active and self.config.getoption("verbose") >= 0:
return "run-last-failure: %s" % self._report_status
def pytest_report_collectionfinish(self) -> str | None:
if self.active and self.config.getoption('verbose') >= 0:
return 'run-last-failure: %s' % self._report_status
return None
def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped:
if (report.when == 'call' and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None)
elif report.failed:
self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
passed = report.outcome in ('passed', 'skipped')
if passed:
if report.nodeid in self.lastfailed:
self.lastfailed.pop(report.nodeid)
@ -337,7 +340,7 @@ class LFPlugin:
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item]
self, config: Config, items: list[nodes.Item],
) -> Generator[None, None, None]:
res = yield
@ -357,45 +360,45 @@ class LFPlugin:
if not previously_failed:
# Running a subset of all tests with recorded failures
# only outside of it.
self._report_status = "%d known failures not in selected tests" % (
self._report_status = '%d known failures not in selected tests' % (
len(self.lastfailed),
)
else:
if self.config.getoption("lf"):
if self.config.getoption('lf'):
items[:] = previously_failed
config.hook.pytest_deselected(items=previously_passed)
else: # --failedfirst
items[:] = previously_failed + previously_passed
noun = "failure" if self._previously_failed_count == 1 else "failures"
suffix = " first" if self.config.getoption("failedfirst") else ""
noun = 'failure' if self._previously_failed_count == 1 else 'failures'
suffix = ' first' if self.config.getoption('failedfirst') else ''
self._report_status = (
f"rerun previous {self._previously_failed_count} {noun}{suffix}"
f'rerun previous {self._previously_failed_count} {noun}{suffix}'
)
if self._skipped_files > 0:
files_noun = "file" if self._skipped_files == 1 else "files"
self._report_status += f" (skipped {self._skipped_files} {files_noun})"
files_noun = 'file' if self._skipped_files == 1 else 'files'
self._report_status += f' (skipped {self._skipped_files} {files_noun})'
else:
self._report_status = "no previously failed tests, "
if self.config.getoption("last_failed_no_failures") == "none":
self._report_status += "deselecting all items."
self._report_status = 'no previously failed tests, '
if self.config.getoption('last_failed_no_failures') == 'none':
self._report_status += 'deselecting all items.'
config.hook.pytest_deselected(items=items[:])
items[:] = []
else:
self._report_status += "not deselecting items."
self._report_status += 'not deselecting items.'
return res
def pytest_sessionfinish(self, session: Session) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
if config.getoption('cacheshow') or hasattr(config, 'workerinput'):
return
assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {})
saved_lastfailed = config.cache.get('cache/lastfailed', {})
if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed)
config.cache.set('cache/lastfailed', self.lastfailed)
class NFPlugin:
@ -405,17 +408,17 @@ class NFPlugin:
self.config = config
self.active = config.option.newfirst
assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
self.cached_nodeids = set(config.cache.get('cache/nodeids', []))
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
self, items: List[nodes.Item]
self, items: list[nodes.Item],
) -> Generator[None, None, None]:
res = yield
if self.active:
new_items: Dict[str, nodes.Item] = {}
other_items: Dict[str, nodes.Item] = {}
new_items: dict[str, nodes.Item] = {}
other_items: dict[str, nodes.Item] = {}
for item in items:
if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item
@ -423,7 +426,7 @@ class NFPlugin:
other_items[item.nodeid] = item
items[:] = self._get_increasing_order(
new_items.values()
new_items.values(),
) + self._get_increasing_order(other_items.values())
self.cached_nodeids.update(new_items)
else:
@ -431,84 +434,84 @@ class NFPlugin:
return res
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> List[nodes.Item]:
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> list[nodes.Item]:
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True) # type: ignore[no-any-return]
def pytest_sessionfinish(self) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
if config.getoption('cacheshow') or hasattr(config, 'workerinput'):
return
if config.getoption("collectonly"):
if config.getoption('collectonly'):
return
assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
config.cache.set('cache/nodeids', sorted(self.cached_nodeids))
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group = parser.getgroup('general')
group.addoption(
"--lf",
"--last-failed",
action="store_true",
dest="lf",
help="Rerun only the tests that failed "
"at the last run (or all if none failed)",
'--lf',
'--last-failed',
action='store_true',
dest='lf',
help='Rerun only the tests that failed '
'at the last run (or all if none failed)',
)
group.addoption(
"--ff",
"--failed-first",
action="store_true",
dest="failedfirst",
help="Run all tests, but run the last failures first. "
"This may re-order tests and thus lead to "
"repeated fixture setup/teardown.",
'--ff',
'--failed-first',
action='store_true',
dest='failedfirst',
help='Run all tests, but run the last failures first. '
'This may re-order tests and thus lead to '
'repeated fixture setup/teardown.',
)
group.addoption(
"--nf",
"--new-first",
action="store_true",
dest="newfirst",
help="Run tests from new files first, then the rest of the tests "
"sorted by file mtime",
'--nf',
'--new-first',
action='store_true',
dest='newfirst',
help='Run tests from new files first, then the rest of the tests '
'sorted by file mtime',
)
group.addoption(
"--cache-show",
action="append",
nargs="?",
dest="cacheshow",
'--cache-show',
action='append',
nargs='?',
dest='cacheshow',
help=(
"Show cache contents, don't perform collection or tests. "
"Optional argument: glob (default: '*')."
),
)
group.addoption(
"--cache-clear",
action="store_true",
dest="cacheclear",
help="Remove all cache contents at start of test run",
'--cache-clear',
action='store_true',
dest='cacheclear',
help='Remove all cache contents at start of test run',
)
cache_dir_default = ".pytest_cache"
if "TOX_ENV_DIR" in os.environ:
cache_dir_default = os.path.join(os.environ["TOX_ENV_DIR"], cache_dir_default)
parser.addini("cache_dir", default=cache_dir_default, help="Cache directory path")
cache_dir_default = '.pytest_cache'
if 'TOX_ENV_DIR' in os.environ:
cache_dir_default = os.path.join(os.environ['TOX_ENV_DIR'], cache_dir_default)
parser.addini('cache_dir', default=cache_dir_default, help='Cache directory path')
group.addoption(
"--lfnf",
"--last-failed-no-failures",
action="store",
dest="last_failed_no_failures",
choices=("all", "none"),
default="all",
help="With ``--lf``, determines whether to execute tests when there "
"are no previously (known) failures or when no "
"cached ``lastfailed`` data was found. "
"``all`` (the default) runs the full test suite again. "
"``none`` just emits a message about no known failures and exits successfully.",
'--lfnf',
'--last-failed-no-failures',
action='store',
dest='last_failed_no_failures',
choices=('all', 'none'),
default='all',
help='With ``--lf``, determines whether to execute tests when there '
'are no previously (known) failures or when no '
'cached ``lastfailed`` data was found. '
'``all`` (the default) runs the full test suite again. '
'``none`` just emits a message about no known failures and exits successfully.',
)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.cacheshow and not config.option.help:
from _pytest.main import wrap_session
@ -519,8 +522,8 @@ def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
@hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None:
config.cache = Cache.for_config(config, _ispytest=True)
config.pluginmanager.register(LFPlugin(config), "lfplugin")
config.pluginmanager.register(NFPlugin(config), "nfplugin")
config.pluginmanager.register(LFPlugin(config), 'lfplugin')
config.pluginmanager.register(NFPlugin(config), 'nfplugin')
@fixture
@ -539,9 +542,9 @@ def cache(request: FixtureRequest) -> Cache:
return request.config.cache
def pytest_report_header(config: Config) -> Optional[str]:
def pytest_report_header(config: Config) -> str | None:
"""Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
if config.option.verbose > 0 or config.getini('cache_dir') != '.pytest_cache':
assert config.cache is not None
cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths
@ -551,7 +554,7 @@ def pytest_report_header(config: Config) -> Optional[str]:
displaypath = cachedir.relative_to(config.rootpath)
except ValueError:
displaypath = cachedir
return f"cachedir: {displaypath}"
return f'cachedir: {displaypath}'
return None
@ -561,37 +564,37 @@ def cacheshow(config: Config, session: Session) -> int:
assert config.cache is not None
tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir))
tw.line('cachedir: ' + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir():
tw.line("cache is empty")
tw.line('cache is empty')
return 0
glob = config.option.cacheshow[0]
if glob is None:
glob = "*"
glob = '*'
dummy = object()
basedir = config.cache._cachedir
vdir = basedir / Cache._CACHE_PREFIX_VALUES
tw.sep("-", "cache values for %r" % glob)
tw.sep('-', 'cache values for %r' % glob)
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
key = str(valpath.relative_to(vdir))
val = config.cache.get(key, dummy)
if val is dummy:
tw.line("%s contains unreadable content, will be ignored" % key)
tw.line('%s contains unreadable content, will be ignored' % key)
else:
tw.line("%s contains:" % key)
tw.line('%s contains:' % key)
for line in pformat(val).splitlines():
tw.line(" " + line)
tw.line(' ' + line)
ddir = basedir / Cache._CACHE_PREFIX_DIRS
if ddir.is_dir():
contents = sorted(ddir.rglob(glob))
tw.sep("-", "cache directories for %r" % glob)
tw.sep('-', 'cache directories for %r' % glob)
for p in contents:
# if p.is_dir():
# print("%s/" % p.relative_to(basedir))
if p.is_file():
key = str(p.relative_to(basedir))
tw.line(f"{key} is a file of length {p.stat().st_size:d}")
tw.line(f'{key} is a file of length {p.stat().st_size:d}')
return 0

View file

@ -1,12 +1,14 @@
# mypy: allow-untyped-defs
"""Per-test stdout/stderr capturing mechanism."""
from __future__ import annotations
import abc
import collections
import contextlib
import io
from io import UnsupportedOperation
import os
import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile
from types import TracebackType
from typing import Any
@ -40,25 +42,25 @@ from _pytest.nodes import Item
from _pytest.reports import CollectReport
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
_CaptureMethod = Literal['fd', 'sys', 'no', 'tee-sys']
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group = parser.getgroup('general')
group._addoption(
"--capture",
action="store",
default="fd",
metavar="method",
choices=["fd", "sys", "no", "tee-sys"],
help="Per-test capturing method: one of fd|sys|no|tee-sys",
'--capture',
action='store',
default='fd',
metavar='method',
choices=['fd', 'sys', 'no', 'tee-sys'],
help='Per-test capturing method: one of fd|sys|no|tee-sys',
)
group._addoption(
"-s",
action="store_const",
const="no",
dest="capture",
help="Shortcut for --capture=no",
'-s',
action='store_const',
const='no',
dest='capture',
help='Shortcut for --capture=no',
)
@ -70,7 +72,7 @@ def _colorama_workaround() -> None:
first import of colorama while I/O capture is active, colorama will
fail in various ways.
"""
if sys.platform.startswith("win32"):
if sys.platform.startswith('win32'):
try:
import colorama # noqa: F401
except ImportError:
@ -101,21 +103,21 @@ def _windowsconsoleio_workaround(stream: TextIO) -> None:
See https://github.com/pytest-dev/py/issues/103.
"""
if not sys.platform.startswith("win32") or hasattr(sys, "pypy_version_info"):
if not sys.platform.startswith('win32') or hasattr(sys, 'pypy_version_info'):
return
# Bail out if ``stream`` doesn't seem like a proper ``io`` stream (#2666).
if not hasattr(stream, "buffer"): # type: ignore[unreachable]
if not hasattr(stream, 'buffer'): # type: ignore[unreachable]
return
buffered = hasattr(stream.buffer, "raw")
buffered = hasattr(stream.buffer, 'raw')
raw_stdout = stream.buffer.raw if buffered else stream.buffer # type: ignore[attr-defined]
if not isinstance(raw_stdout, io._WindowsConsoleIO): # type: ignore[attr-defined]
return
def _reopen_stdio(f, mode):
if not buffered and mode[0] == "w":
if not buffered and mode[0] == 'w':
buffering = 0
else:
buffering = -1
@ -128,20 +130,20 @@ def _windowsconsoleio_workaround(stream: TextIO) -> None:
f.line_buffering,
)
sys.stdin = _reopen_stdio(sys.stdin, "rb")
sys.stdout = _reopen_stdio(sys.stdout, "wb")
sys.stderr = _reopen_stdio(sys.stderr, "wb")
sys.stdin = _reopen_stdio(sys.stdin, 'rb')
sys.stdout = _reopen_stdio(sys.stdout, 'wb')
sys.stderr = _reopen_stdio(sys.stderr, 'wb')
@hookimpl(wrapper=True)
def pytest_load_initial_conftests(early_config: Config) -> Generator[None, None, None]:
ns = early_config.known_args_namespace
if ns.capture == "fd":
if ns.capture == 'fd':
_windowsconsoleio_workaround(sys.stdout)
_colorama_workaround()
pluginmanager = early_config.pluginmanager
capman = CaptureManager(ns.capture)
pluginmanager.register(capman, "capturemanager")
pluginmanager.register(capman, 'capturemanager')
# Make sure that capturemanager is properly reset at final shutdown.
early_config.add_cleanup(capman.stop_global_capturing)
@ -176,16 +178,16 @@ class EncodedFile(io.TextIOWrapper):
def mode(self) -> str:
# TextIOWrapper doesn't expose a mode, but at least some of our
# tests check it.
return self.buffer.mode.replace("b", "")
return self.buffer.mode.replace('b', '')
class CaptureIO(io.TextIOWrapper):
def __init__(self) -> None:
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
super().__init__(io.BytesIO(), encoding='UTF-8', newline='', write_through=True)
def getvalue(self) -> str:
assert isinstance(self.buffer, io.BytesIO)
return self.buffer.getvalue().decode("UTF-8")
return self.buffer.getvalue().decode('UTF-8')
class TeeCaptureIO(CaptureIO):
@ -205,7 +207,7 @@ class DontReadFromInput(TextIO):
def read(self, size: int = -1) -> str:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
'pytest: reading from stdin while output is captured! Consider using `-s`.',
)
readline = read
@ -213,19 +215,19 @@ class DontReadFromInput(TextIO):
def __next__(self) -> str:
return self.readline()
def readlines(self, hint: Optional[int] = -1) -> List[str]:
def readlines(self, hint: int | None = -1) -> list[str]:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
'pytest: reading from stdin while output is captured! Consider using `-s`.',
)
def __iter__(self) -> Iterator[str]:
return self
def fileno(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no fileno()")
raise UnsupportedOperation('redirected stdin is pseudofile, has no fileno()')
def flush(self) -> None:
raise UnsupportedOperation("redirected stdin is pseudofile, has no flush()")
raise UnsupportedOperation('redirected stdin is pseudofile, has no flush()')
def isatty(self) -> bool:
return False
@ -237,34 +239,34 @@ class DontReadFromInput(TextIO):
return False
def seek(self, offset: int, whence: int = 0) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)")
raise UnsupportedOperation('redirected stdin is pseudofile, has no seek(int)')
def seekable(self) -> bool:
return False
def tell(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()")
raise UnsupportedOperation('redirected stdin is pseudofile, has no tell()')
def truncate(self, size: Optional[int] = None) -> int:
raise UnsupportedOperation("cannot truncate stdin")
def truncate(self, size: int | None = None) -> int:
raise UnsupportedOperation('cannot truncate stdin')
def write(self, data: str) -> int:
raise UnsupportedOperation("cannot write to stdin")
raise UnsupportedOperation('cannot write to stdin')
def writelines(self, lines: Iterable[str]) -> None:
raise UnsupportedOperation("Cannot write to stdin")
raise UnsupportedOperation('Cannot write to stdin')
def writable(self) -> bool:
return False
def __enter__(self) -> "DontReadFromInput":
def __enter__(self) -> DontReadFromInput:
return self
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
type: type[BaseException] | None,
value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
@ -309,11 +311,11 @@ class CaptureBase(abc.ABC, Generic[AnyStr]):
raise NotImplementedError()
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
patchsysdict = {0: 'stdin', 1: 'stdout', 2: 'stderr'}
class NoCapture(CaptureBase[str]):
EMPTY_BUFFER = ""
EMPTY_BUFFER = ''
def __init__(self, fd: int) -> None:
pass
@ -331,7 +333,7 @@ class NoCapture(CaptureBase[str]):
pass
def snap(self) -> str:
return ""
return ''
def writeorg(self, data: str) -> None:
pass
@ -339,76 +341,76 @@ class NoCapture(CaptureBase[str]):
class SysCaptureBase(CaptureBase[AnyStr]):
def __init__(
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
self, fd: int, tmpfile: TextIO | None = None, *, tee: bool = False,
) -> None:
name = patchsysdict[fd]
self._old: TextIO = getattr(sys, name)
self.name = name
if tmpfile is None:
if name == "stdin":
if name == 'stdin':
tmpfile = DontReadFromInput()
else:
tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old)
self.tmpfile = tmpfile
self._state = "initialized"
self._state = 'initialized'
def repr(self, class_name: str) -> str:
return "<{} {} _old={} _state={!r} tmpfile={!r}>".format(
return '<{} {} _old={} _state={!r} tmpfile={!r}>'.format(
class_name,
self.name,
hasattr(self, "_old") and repr(self._old) or "<UNSET>",
hasattr(self, '_old') and repr(self._old) or '<UNSET>',
self._state,
self.tmpfile,
)
def __repr__(self) -> str:
return "<{} {} _old={} _state={!r} tmpfile={!r}>".format(
return '<{} {} _old={} _state={!r} tmpfile={!r}>'.format(
self.__class__.__name__,
self.name,
hasattr(self, "_old") and repr(self._old) or "<UNSET>",
hasattr(self, '_old') and repr(self._old) or '<UNSET>',
self._state,
self.tmpfile,
)
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert (
self._state in states
), "cannot {} in state {!r}: expected one of {}".format(
op, self._state, ", ".join(states)
), 'cannot {} in state {!r}: expected one of {}'.format(
op, self._state, ', '.join(states),
)
def start(self) -> None:
self._assert_state("start", ("initialized",))
self._assert_state('start', ('initialized',))
setattr(sys, self.name, self.tmpfile)
self._state = "started"
self._state = 'started'
def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done":
self._assert_state('done', ('initialized', 'started', 'suspended', 'done'))
if self._state == 'done':
return
setattr(sys, self.name, self._old)
del self._old
self.tmpfile.close()
self._state = "done"
self._state = 'done'
def suspend(self) -> None:
self._assert_state("suspend", ("started", "suspended"))
self._assert_state('suspend', ('started', 'suspended'))
setattr(sys, self.name, self._old)
self._state = "suspended"
self._state = 'suspended'
def resume(self) -> None:
self._assert_state("resume", ("started", "suspended"))
if self._state == "started":
self._assert_state('resume', ('started', 'suspended'))
if self._state == 'started':
return
setattr(sys, self.name, self.tmpfile)
self._state = "started"
self._state = 'started'
class SysCaptureBinary(SysCaptureBase[bytes]):
EMPTY_BUFFER = b""
EMPTY_BUFFER = b''
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self._assert_state('snap', ('started', 'suspended'))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
@ -416,17 +418,17 @@ class SysCaptureBinary(SysCaptureBase[bytes]):
return res
def writeorg(self, data: bytes) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._assert_state('writeorg', ('started', 'suspended'))
self._old.flush()
self._old.buffer.write(data)
self._old.buffer.flush()
class SysCapture(SysCaptureBase[str]):
EMPTY_BUFFER = ""
EMPTY_BUFFER = ''
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
self._assert_state('snap', ('started', 'suspended'))
assert isinstance(self.tmpfile, CaptureIO)
res = self.tmpfile.getvalue()
self.tmpfile.seek(0)
@ -434,7 +436,7 @@ class SysCapture(SysCaptureBase[str]):
return res
def writeorg(self, data: str) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._assert_state('writeorg', ('started', 'suspended'))
self._old.write(data)
self._old.flush()
@ -457,21 +459,21 @@ class FDCaptureBase(CaptureBase[AnyStr]):
# Further complications are the need to support suspend() and the
# possibility of FD reuse (e.g. the tmpfile getting the very same
# target FD). The following approach is robust, I believe.
self.targetfd_invalid: Optional[int] = os.open(os.devnull, os.O_RDWR)
self.targetfd_invalid: int | None = os.open(os.devnull, os.O_RDWR)
os.dup2(self.targetfd_invalid, targetfd)
else:
self.targetfd_invalid = None
self.targetfd_save = os.dup(targetfd)
if targetfd == 0:
self.tmpfile = open(os.devnull, encoding="utf-8")
self.tmpfile = open(os.devnull, encoding='utf-8')
self.syscapture: CaptureBase[str] = SysCapture(targetfd)
else:
self.tmpfile = EncodedFile(
TemporaryFile(buffering=0),
encoding="utf-8",
errors="replace",
newline="",
encoding='utf-8',
errors='replace',
newline='',
write_through=True,
)
if targetfd in patchsysdict:
@ -479,10 +481,10 @@ class FDCaptureBase(CaptureBase[AnyStr]):
else:
self.syscapture = NoCapture(targetfd)
self._state = "initialized"
self._state = 'initialized'
def __repr__(self) -> str:
return "<{} {} oldfd={} _state={!r} tmpfile={!r}>".format(
return '<{} {} oldfd={} _state={!r} tmpfile={!r}>'.format(
self.__class__.__name__,
self.targetfd,
self.targetfd_save,
@ -490,25 +492,25 @@ class FDCaptureBase(CaptureBase[AnyStr]):
self.tmpfile,
)
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert (
self._state in states
), "cannot {} in state {!r}: expected one of {}".format(
op, self._state, ", ".join(states)
), 'cannot {} in state {!r}: expected one of {}'.format(
op, self._state, ', '.join(states),
)
def start(self) -> None:
"""Start capturing on targetfd using memorized tmpfile."""
self._assert_state("start", ("initialized",))
self._assert_state('start', ('initialized',))
os.dup2(self.tmpfile.fileno(), self.targetfd)
self.syscapture.start()
self._state = "started"
self._state = 'started'
def done(self) -> None:
"""Stop capturing, restore streams, return original capture file,
seeked to position zero."""
self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done":
self._assert_state('done', ('initialized', 'started', 'suspended', 'done'))
if self._state == 'done':
return
os.dup2(self.targetfd_save, self.targetfd)
os.close(self.targetfd_save)
@ -518,23 +520,23 @@ class FDCaptureBase(CaptureBase[AnyStr]):
os.close(self.targetfd_invalid)
self.syscapture.done()
self.tmpfile.close()
self._state = "done"
self._state = 'done'
def suspend(self) -> None:
self._assert_state("suspend", ("started", "suspended"))
if self._state == "suspended":
self._assert_state('suspend', ('started', 'suspended'))
if self._state == 'suspended':
return
self.syscapture.suspend()
os.dup2(self.targetfd_save, self.targetfd)
self._state = "suspended"
self._state = 'suspended'
def resume(self) -> None:
self._assert_state("resume", ("started", "suspended"))
if self._state == "started":
self._assert_state('resume', ('started', 'suspended'))
if self._state == 'started':
return
self.syscapture.resume()
os.dup2(self.tmpfile.fileno(), self.targetfd)
self._state = "started"
self._state = 'started'
class FDCaptureBinary(FDCaptureBase[bytes]):
@ -543,10 +545,10 @@ class FDCaptureBinary(FDCaptureBase[bytes]):
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
EMPTY_BUFFER = b''
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self._assert_state('snap', ('started', 'suspended'))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
@ -555,7 +557,7 @@ class FDCaptureBinary(FDCaptureBase[bytes]):
def writeorg(self, data: bytes) -> None:
"""Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended"))
self._assert_state('writeorg', ('started', 'suspended'))
os.write(self.targetfd_save, data)
@ -565,10 +567,10 @@ class FDCapture(FDCaptureBase[str]):
snap() produces text.
"""
EMPTY_BUFFER = ""
EMPTY_BUFFER = ''
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
self._assert_state('snap', ('started', 'suspended'))
self.tmpfile.seek(0)
res = self.tmpfile.read()
self.tmpfile.seek(0)
@ -577,9 +579,9 @@ class FDCapture(FDCaptureBase[str]):
def writeorg(self, data: str) -> None:
"""Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended"))
self._assert_state('writeorg', ('started', 'suspended'))
# XXX use encoding of original stream
os.write(self.targetfd_save, data.encode("utf-8"))
os.write(self.targetfd_save, data.encode('utf-8'))
# MultiCapture
@ -598,7 +600,7 @@ if sys.version_info >= (3, 11) or TYPE_CHECKING:
else:
class CaptureResult(
collections.namedtuple("CaptureResult", ["out", "err"]), # noqa: PYI024
collections.namedtuple('CaptureResult', ['out', 'err']), # noqa: PYI024
Generic[AnyStr],
):
"""The result of :method:`caplog.readouterr() <pytest.CaptureFixture.readouterr>`."""
@ -612,16 +614,16 @@ class MultiCapture(Generic[AnyStr]):
def __init__(
self,
in_: Optional[CaptureBase[AnyStr]],
out: Optional[CaptureBase[AnyStr]],
err: Optional[CaptureBase[AnyStr]],
in_: CaptureBase[AnyStr] | None,
out: CaptureBase[AnyStr] | None,
err: CaptureBase[AnyStr] | None,
) -> None:
self.in_: Optional[CaptureBase[AnyStr]] = in_
self.out: Optional[CaptureBase[AnyStr]] = out
self.err: Optional[CaptureBase[AnyStr]] = err
self.in_: CaptureBase[AnyStr] | None = in_
self.out: CaptureBase[AnyStr] | None = out
self.err: CaptureBase[AnyStr] | None = err
def __repr__(self) -> str:
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
return '<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>'.format(
self.out,
self.err,
self.in_,
@ -630,7 +632,7 @@ class MultiCapture(Generic[AnyStr]):
)
def start_capturing(self) -> None:
self._state = "started"
self._state = 'started'
if self.in_:
self.in_.start()
if self.out:
@ -638,7 +640,7 @@ class MultiCapture(Generic[AnyStr]):
if self.err:
self.err.start()
def pop_outerr_to_orig(self) -> Tuple[AnyStr, AnyStr]:
def pop_outerr_to_orig(self) -> tuple[AnyStr, AnyStr]:
"""Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr()
if out:
@ -650,7 +652,7 @@ class MultiCapture(Generic[AnyStr]):
return out, err
def suspend_capturing(self, in_: bool = False) -> None:
self._state = "suspended"
self._state = 'suspended'
if self.out:
self.out.suspend()
if self.err:
@ -660,7 +662,7 @@ class MultiCapture(Generic[AnyStr]):
self._in_suspended = True
def resume_capturing(self) -> None:
self._state = "started"
self._state = 'started'
if self.out:
self.out.resume()
if self.err:
@ -672,9 +674,9 @@ class MultiCapture(Generic[AnyStr]):
def stop_capturing(self) -> None:
"""Stop capturing and reset capturing streams."""
if self._state == "stopped":
raise ValueError("was already stopped")
self._state = "stopped"
if self._state == 'stopped':
raise ValueError('was already stopped')
self._state = 'stopped'
if self.out:
self.out.done()
if self.err:
@ -684,27 +686,27 @@ class MultiCapture(Generic[AnyStr]):
def is_started(self) -> bool:
"""Whether actively capturing -- not suspended or stopped."""
return self._state == "started"
return self._state == 'started'
def readouterr(self) -> CaptureResult[AnyStr]:
out = self.out.snap() if self.out else ""
err = self.err.snap() if self.err else ""
out = self.out.snap() if self.out else ''
err = self.err.snap() if self.err else ''
# TODO: This type error is real, need to fix.
return CaptureResult(out, err) # type: ignore[arg-type]
def _get_multicapture(method: _CaptureMethod) -> MultiCapture[str]:
if method == "fd":
if method == 'fd':
return MultiCapture(in_=FDCapture(0), out=FDCapture(1), err=FDCapture(2))
elif method == "sys":
elif method == 'sys':
return MultiCapture(in_=SysCapture(0), out=SysCapture(1), err=SysCapture(2))
elif method == "no":
elif method == 'no':
return MultiCapture(in_=None, out=None, err=None)
elif method == "tee-sys":
elif method == 'tee-sys':
return MultiCapture(
in_=None, out=SysCapture(1, tee=True), err=SysCapture(2, tee=True)
in_=None, out=SysCapture(1, tee=True), err=SysCapture(2, tee=True),
)
raise ValueError(f"unknown capturing method: {method!r}")
raise ValueError(f'unknown capturing method: {method!r}')
# CaptureManager and CaptureFixture
@ -731,25 +733,25 @@ class CaptureManager:
def __init__(self, method: _CaptureMethod) -> None:
self._method: Final = method
self._global_capturing: Optional[MultiCapture[str]] = None
self._capture_fixture: Optional[CaptureFixture[Any]] = None
self._global_capturing: MultiCapture[str] | None = None
self._capture_fixture: CaptureFixture[Any] | None = None
def __repr__(self) -> str:
return "<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>".format(
self._method, self._global_capturing, self._capture_fixture
return '<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>'.format(
self._method, self._global_capturing, self._capture_fixture,
)
def is_capturing(self) -> Union[str, bool]:
def is_capturing(self) -> str | bool:
if self.is_globally_capturing():
return "global"
return 'global'
if self._capture_fixture:
return "fixture %s" % self._capture_fixture.request.fixturename
return 'fixture %s' % self._capture_fixture.request.fixturename
return False
# Global capturing control
def is_globally_capturing(self) -> bool:
return self._method != "no"
return self._method != 'no'
def start_global_capturing(self) -> None:
assert self._global_capturing is None
@ -787,12 +789,12 @@ class CaptureManager:
# Fixture Control
def set_fixture(self, capture_fixture: "CaptureFixture[Any]") -> None:
def set_fixture(self, capture_fixture: CaptureFixture[Any]) -> None:
if self._capture_fixture:
current_fixture = self._capture_fixture.request.fixturename
requested_fixture = capture_fixture.request.fixturename
capture_fixture.request.raiseerror(
f"cannot use {requested_fixture} and {current_fixture} at the same time"
f'cannot use {requested_fixture} and {current_fixture} at the same time',
)
self._capture_fixture = capture_fixture
@ -848,14 +850,14 @@ class CaptureManager:
self.suspend_global_capture(in_=False)
out, err = self.read_global_capture()
item.add_report_section(when, "stdout", out)
item.add_report_section(when, "stderr", err)
item.add_report_section(when, 'stdout', out)
item.add_report_section(when, 'stderr', err)
# Hooks
@hookimpl(wrapper=True)
def pytest_make_collect_report(
self, collector: Collector
self, collector: Collector,
) -> Generator[None, CollectReport, CollectReport]:
if isinstance(collector, File):
self.resume_global_capture()
@ -865,26 +867,26 @@ class CaptureManager:
self.suspend_global_capture()
out, err = self.read_global_capture()
if out:
rep.sections.append(("Captured stdout", out))
rep.sections.append(('Captured stdout', out))
if err:
rep.sections.append(("Captured stderr", err))
rep.sections.append(('Captured stderr', err))
else:
rep = yield
return rep
@hookimpl(wrapper=True)
def pytest_runtest_setup(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("setup", item):
with self.item_capture('setup', item):
return (yield)
@hookimpl(wrapper=True)
def pytest_runtest_call(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("call", item):
with self.item_capture('call', item):
return (yield)
@hookimpl(wrapper=True)
def pytest_runtest_teardown(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("teardown", item):
with self.item_capture('teardown', item):
return (yield)
@hookimpl(tryfirst=True)
@ -902,15 +904,15 @@ class CaptureFixture(Generic[AnyStr]):
def __init__(
self,
captureclass: Type[CaptureBase[AnyStr]],
captureclass: type[CaptureBase[AnyStr]],
request: SubRequest,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self.captureclass: Type[CaptureBase[AnyStr]] = captureclass
self.captureclass: type[CaptureBase[AnyStr]] = captureclass
self.request = request
self._capture: Optional[MultiCapture[AnyStr]] = None
self._capture: MultiCapture[AnyStr] | None = None
self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER
self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER
@ -968,7 +970,7 @@ class CaptureFixture(Generic[AnyStr]):
def disabled(self) -> Generator[None, None, None]:
"""Temporarily disable capturing while inside the ``with`` block."""
capmanager: CaptureManager = self.request.config.pluginmanager.getplugin(
"capturemanager"
'capturemanager',
)
with capmanager.global_and_fixture_disabled():
yield
@ -995,7 +997,7 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capsys.readouterr()
assert captured.out == "hello\n"
"""
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
@ -1022,7 +1024,7 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None,
captured = capsysbinary.readouterr()
assert captured.out == b"hello\n"
"""
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
@ -1049,7 +1051,7 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capfd.readouterr()
assert captured.out == "hello\n"
"""
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
@ -1077,7 +1079,7 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N
assert captured.out == b"hello\n"
"""
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()

View file

@ -1,17 +1,16 @@
# mypy: allow-untyped-defs
"""Python version compatibility code."""
from __future__ import annotations
import dataclasses
import enum
import functools
import inspect
import os
import sys
from inspect import Parameter
from inspect import signature
import os
from pathlib import Path
import sys
from typing import Any
from typing import Callable
from typing import Final
@ -57,7 +56,7 @@ def iscoroutinefunction(func: object) -> bool:
importing asyncio directly, which in turns also initializes the "logging"
module as a side-effect (see issue #8).
"""
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
return inspect.iscoroutinefunction(func) or getattr(func, '_is_coroutine', False)
def is_async_function(func: object) -> bool:
@ -76,33 +75,33 @@ def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
except ValueError:
pass
else:
return "%s:%d" % (relfn, lineno + 1)
return "%s:%d" % (fn, lineno + 1)
return '%s:%d' % (relfn, lineno + 1)
return '%s:%d' % (fn, lineno + 1)
def num_mock_patch_args(function) -> int:
"""Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None)
patchings = getattr(function, 'patchings', None)
if not patchings:
return 0
mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object())
ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object())
mock_sentinel = getattr(sys.modules.get('mock'), 'DEFAULT', object())
ut_mock_sentinel = getattr(sys.modules.get('unittest.mock'), 'DEFAULT', object())
return len(
[
p
for p in patchings
if not p.attribute_name
and (p.new is mock_sentinel or p.new is ut_mock_sentinel)
]
if not p.attribute_name and
(p.new is mock_sentinel or p.new is ut_mock_sentinel)
],
)
def getfuncargnames(
function: Callable[..., object],
*,
name: str = "",
name: str = '',
is_method: bool = False,
cls: type | None = None,
) -> tuple[str, ...]:
@ -135,7 +134,7 @@ def getfuncargnames(
from _pytest.outcomes import fail
fail(
f"Could not determine arguments of {function!r}: {e}",
f'Could not determine arguments of {function!r}: {e}',
pytrace=False,
)
@ -143,10 +142,10 @@ def getfuncargnames(
p.name
for p in parameters.values()
if (
p.kind is Parameter.POSITIONAL_OR_KEYWORD
or p.kind is Parameter.KEYWORD_ONLY
)
and p.default is Parameter.empty
p.kind is Parameter.POSITIONAL_OR_KEYWORD or
p.kind is Parameter.KEYWORD_ONLY
) and
p.default is Parameter.empty
)
if not name:
name = function.__name__
@ -157,15 +156,15 @@ def getfuncargnames(
if is_method or (
# Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO.
cls
and not isinstance(
inspect.getattr_static(cls, name, default=None), staticmethod
cls and
not isinstance(
inspect.getattr_static(cls, name, default=None), staticmethod,
)
):
arg_names = arg_names[1:]
# Remove any names that will be replaced with mocks.
if hasattr(function, "__wrapped__"):
arg_names = arg_names[num_mock_patch_args(function) :]
if hasattr(function, '__wrapped__'):
arg_names = arg_names[num_mock_patch_args(function):]
return arg_names
@ -176,16 +175,16 @@ def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]:
return tuple(
p.name
for p in signature(function).parameters.values()
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
and p.default is not Parameter.empty
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY) and
p.default is not Parameter.empty
)
_non_printable_ascii_translate_table = {
i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127)
i: f'\\x{i:02x}' for i in range(128) if i not in range(32, 127)
}
_non_printable_ascii_translate_table.update(
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"}
{ord('\t'): '\\t', ord('\r'): '\\r', ord('\n'): '\\n'},
)
@ -206,9 +205,9 @@ def ascii_escaped(val: bytes | str) -> str:
a UTF-8 string.
"""
if isinstance(val, bytes):
ret = val.decode("ascii", "backslashreplace")
ret = val.decode('ascii', 'backslashreplace')
else:
ret = val.encode("unicode_escape").decode("ascii")
ret = val.encode('unicode_escape').decode('ascii')
return ret.translate(_non_printable_ascii_translate_table)
@ -232,11 +231,11 @@ def get_real_func(obj):
# __pytest_wrapped__ is set by @pytest.fixture when wrapping the fixture function
# to trigger a warning if it gets called directly instead of by pytest: we don't
# want to unwrap further than this otherwise we lose useful wrappings like @mock.patch (#3774)
new_obj = getattr(obj, "__pytest_wrapped__", None)
new_obj = getattr(obj, '__pytest_wrapped__', None)
if isinstance(new_obj, _PytestWrapper):
obj = new_obj.obj
break
new_obj = getattr(obj, "__wrapped__", None)
new_obj = getattr(obj, '__wrapped__', None)
if new_obj is None:
break
obj = new_obj
@ -244,7 +243,7 @@ def get_real_func(obj):
from _pytest._io.saferepr import saferepr
raise ValueError(
f"could not find real function of {saferepr(start_obj)}\nstopped at {saferepr(obj)}"
f'could not find real function of {saferepr(start_obj)}\nstopped at {saferepr(obj)}',
)
if isinstance(obj, functools.partial):
obj = obj.func
@ -256,11 +255,11 @@ def get_real_method(obj, holder):
``obj``, while at the same time returning a bound method to ``holder`` if
the original object was a bound method."""
try:
is_method = hasattr(obj, "__func__")
is_method = hasattr(obj, '__func__')
obj = get_real_func(obj)
except Exception: # pragma: no cover
return obj
if is_method and hasattr(obj, "__get__") and callable(obj.__get__):
if is_method and hasattr(obj, '__get__') and callable(obj.__get__):
obj = obj.__get__(holder)
return obj
@ -306,7 +305,7 @@ def get_user_id() -> int | None:
# mypy follows the version and platform checking expectation of PEP 484:
# https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=platform#python-version-and-system-platform-checks
# Containment checks are too complex for mypy v1.5.0 and cause failure.
if sys.platform == "win32" or sys.platform == "emscripten": # noqa: PLR1714
if sys.platform == 'win32' or sys.platform == 'emscripten': # noqa: PLR1714
# win32 does not have a getuid() function.
# Emscripten has a return 0 stub.
return None
@ -350,4 +349,4 @@ def get_user_id() -> int | None:
#
# This also work for Enums (if you use `is` to compare) and Literals.
def assert_never(value: NoReturn) -> NoReturn:
assert False, f"Unhandled value: {value} ({type(value).__name__})"
assert False, f'Unhandled value: {value} ({type(value).__name__})'

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,10 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import argparse
from gettext import gettext
import os
import sys
from gettext import gettext
from typing import Any
from typing import Callable
from typing import cast
@ -22,12 +24,12 @@ from _pytest.config.exceptions import UsageError
from _pytest.deprecated import check_ispytest
FILE_OR_DIR = "file_or_dir"
FILE_OR_DIR = 'file_or_dir'
class NotSet:
def __repr__(self) -> str:
return "<notset>"
return '<notset>'
NOT_SET = NotSet()
@ -41,32 +43,32 @@ class Parser:
there's an error processing the command line arguments.
"""
prog: Optional[str] = None
prog: str | None = None
def __init__(
self,
usage: Optional[str] = None,
processopt: Optional[Callable[["Argument"], None]] = None,
usage: str | None = None,
processopt: Callable[[Argument], None] | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True)
self._groups: List[OptionGroup] = []
self._anonymous = OptionGroup('Custom options', parser=self, _ispytest=True)
self._groups: list[OptionGroup] = []
self._processopt = processopt
self._usage = usage
self._inidict: Dict[str, Tuple[str, Optional[str], Any]] = {}
self._ininames: List[str] = []
self.extra_info: Dict[str, Any] = {}
self._inidict: dict[str, tuple[str, str | None, Any]] = {}
self._ininames: list[str] = []
self.extra_info: dict[str, Any] = {}
def processoption(self, option: "Argument") -> None:
def processoption(self, option: Argument) -> None:
if self._processopt:
if option.dest:
self._processopt(option)
def getgroup(
self, name: str, description: str = "", after: Optional[str] = None
) -> "OptionGroup":
self, name: str, description: str = '', after: str | None = None,
) -> OptionGroup:
"""Get (or create) a named option Group.
:param name: Name of the option group.
@ -108,8 +110,8 @@ class Parser:
def parse(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
namespace: Optional[argparse.Namespace] = None,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
from _pytest._argcomplete import try_argcomplete
@ -118,7 +120,7 @@ class Parser:
strargs = [os.fspath(x) for x in args]
return self.optparser.parse_args(strargs, namespace=namespace)
def _getparser(self) -> "MyOptionParser":
def _getparser(self) -> MyOptionParser:
from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
@ -131,7 +133,7 @@ class Parser:
n = option.names()
a = option.attrs()
arggroup.add_argument(*n, **a)
file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs="*")
file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs='*')
# bash like autocompletion for dirs (appending '/')
# Type ignored because typeshed doesn't know about argcomplete.
file_or_dir_arg.completer = filescompleter # type: ignore
@ -139,10 +141,10 @@ class Parser:
def parse_setoption(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
args: Sequence[str | os.PathLike[str]],
option: argparse.Namespace,
namespace: Optional[argparse.Namespace] = None,
) -> List[str]:
namespace: argparse.Namespace | None = None,
) -> list[str]:
parsedoption = self.parse(args, namespace=namespace)
for name, value in parsedoption.__dict__.items():
setattr(option, name, value)
@ -150,8 +152,8 @@ class Parser:
def parse_known_args(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
namespace: Optional[argparse.Namespace] = None,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
"""Parse the known arguments at this point.
@ -161,9 +163,9 @@ class Parser:
def parse_known_and_unknown_args(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
namespace: Optional[argparse.Namespace] = None,
) -> Tuple[argparse.Namespace, List[str]]:
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> tuple[argparse.Namespace, list[str]]:
"""Parse the known arguments at this point, and also return the
remaining unknown arguments.
@ -179,9 +181,9 @@ class Parser:
self,
name: str,
help: str,
type: Optional[
Literal["string", "paths", "pathlist", "args", "linelist", "bool"]
] = None,
type: None | (
Literal['string', 'paths', 'pathlist', 'args', 'linelist', 'bool']
) = None,
default: Any = NOT_SET,
) -> None:
"""Register an ini-file option.
@ -215,7 +217,7 @@ class Parser:
The value of ini-variables can be retrieved via a call to
:py:func:`config.getini(name) <pytest.Config.getini>`.
"""
assert type in (None, "string", "paths", "pathlist", "args", "linelist", "bool")
assert type in (None, 'string', 'paths', 'pathlist', 'args', 'linelist', 'bool')
if default is NOT_SET:
default = get_ini_default_for_type(type)
@ -224,33 +226,33 @@ class Parser:
def get_ini_default_for_type(
type: Optional[Literal["string", "paths", "pathlist", "args", "linelist", "bool"]],
type: Literal['string', 'paths', 'pathlist', 'args', 'linelist', 'bool'] | None,
) -> Any:
"""
Used by addini to get the default value for a given ini-option type, when
default is not supplied.
"""
if type is None:
return ""
elif type in ("paths", "pathlist", "args", "linelist"):
return ''
elif type in ('paths', 'pathlist', 'args', 'linelist'):
return []
elif type == "bool":
elif type == 'bool':
return False
else:
return ""
return ''
class ArgumentError(Exception):
"""Raised if an Argument instance is created with invalid or
inconsistent arguments."""
def __init__(self, msg: str, option: Union["Argument", str]) -> None:
def __init__(self, msg: str, option: Argument | str) -> None:
self.msg = msg
self.option_id = str(option)
def __str__(self) -> str:
if self.option_id:
return f"option {self.option_id}: {self.msg}"
return f'option {self.option_id}: {self.msg}'
else:
return self.msg
@ -267,36 +269,36 @@ class Argument:
def __init__(self, *names: str, **attrs: Any) -> None:
"""Store params in private vars for use in add_argument."""
self._attrs = attrs
self._short_opts: List[str] = []
self._long_opts: List[str] = []
self._short_opts: list[str] = []
self._long_opts: list[str] = []
try:
self.type = attrs["type"]
self.type = attrs['type']
except KeyError:
pass
try:
# Attribute existence is tested in Config._processopt.
self.default = attrs["default"]
self.default = attrs['default']
except KeyError:
pass
self._set_opt_strings(names)
dest: Optional[str] = attrs.get("dest")
dest: str | None = attrs.get('dest')
if dest:
self.dest = dest
elif self._long_opts:
self.dest = self._long_opts[0][2:].replace("-", "_")
self.dest = self._long_opts[0][2:].replace('-', '_')
else:
try:
self.dest = self._short_opts[0][1:]
except IndexError as e:
self.dest = "???" # Needed for the error repr.
raise ArgumentError("need a long or short option", self) from e
self.dest = '???' # Needed for the error repr.
raise ArgumentError('need a long or short option', self) from e
def names(self) -> List[str]:
def names(self) -> list[str]:
return self._short_opts + self._long_opts
def attrs(self) -> Mapping[str, Any]:
# Update any attributes set by processopt.
attrs = "default dest help".split()
attrs = 'default dest help'.split()
attrs.append(self.dest)
for attr in attrs:
try:
@ -313,39 +315,39 @@ class Argument:
for opt in opts:
if len(opt) < 2:
raise ArgumentError(
"invalid option string %r: "
"must be at least two characters long" % opt,
'invalid option string %r: '
'must be at least two characters long' % opt,
self,
)
elif len(opt) == 2:
if not (opt[0] == "-" and opt[1] != "-"):
if not (opt[0] == '-' and opt[1] != '-'):
raise ArgumentError(
"invalid short option string %r: "
"must be of the form -x, (x any non-dash char)" % opt,
'invalid short option string %r: '
'must be of the form -x, (x any non-dash char)' % opt,
self,
)
self._short_opts.append(opt)
else:
if not (opt[0:2] == "--" and opt[2] != "-"):
if not (opt[0:2] == '--' and opt[2] != '-'):
raise ArgumentError(
"invalid long option string %r: "
"must start with --, followed by non-dash" % opt,
'invalid long option string %r: '
'must start with --, followed by non-dash' % opt,
self,
)
self._long_opts.append(opt)
def __repr__(self) -> str:
args: List[str] = []
args: list[str] = []
if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)]
args += ['_short_opts: ' + repr(self._short_opts)]
if self._long_opts:
args += ["_long_opts: " + repr(self._long_opts)]
args += ["dest: " + repr(self.dest)]
if hasattr(self, "type"):
args += ["type: " + repr(self.type)]
if hasattr(self, "default"):
args += ["default: " + repr(self.default)]
return "Argument({})".format(", ".join(args))
args += ['_long_opts: ' + repr(self._long_opts)]
args += ['dest: ' + repr(self.dest)]
if hasattr(self, 'type'):
args += ['type: ' + repr(self.type)]
if hasattr(self, 'default'):
args += ['default: ' + repr(self.default)]
return 'Argument({})'.format(', '.join(args))
class OptionGroup:
@ -354,15 +356,15 @@ class OptionGroup:
def __init__(
self,
name: str,
description: str = "",
parser: Optional[Parser] = None,
description: str = '',
parser: Parser | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self.name = name
self.description = description
self.options: List[Argument] = []
self.options: list[Argument] = []
self.parser = parser
def addoption(self, *opts: str, **attrs: Any) -> None:
@ -383,7 +385,7 @@ class OptionGroup:
name for opt in self.options for name in opt.names()
)
if conflict:
raise ValueError("option names %s already added" % conflict)
raise ValueError('option names %s already added' % conflict)
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=False)
@ -391,11 +393,11 @@ class OptionGroup:
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=True)
def _addoption_instance(self, option: "Argument", shortupper: bool = False) -> None:
def _addoption_instance(self, option: Argument, shortupper: bool = False) -> None:
if not shortupper:
for opt in option._short_opts:
if opt[0] == "-" and opt[1].islower():
raise ValueError("lowercase shortoptions reserved")
if opt[0] == '-' and opt[1].islower():
raise ValueError('lowercase shortoptions reserved')
if self.parser:
self.parser.processoption(option)
self.options.append(option)
@ -405,8 +407,8 @@ class MyOptionParser(argparse.ArgumentParser):
def __init__(
self,
parser: Parser,
extra_info: Optional[Dict[str, Any]] = None,
prog: Optional[str] = None,
extra_info: dict[str, Any] | None = None,
prog: str | None = None,
) -> None:
self._parser = parser
super().__init__(
@ -422,29 +424,29 @@ class MyOptionParser(argparse.ArgumentParser):
def error(self, message: str) -> NoReturn:
"""Transform argparse error message into UsageError."""
msg = f"{self.prog}: error: {message}"
msg = f'{self.prog}: error: {message}'
if hasattr(self._parser, "_config_source_hint"):
if hasattr(self._parser, '_config_source_hint'):
# Type ignored because the attribute is set dynamically.
msg = f"{msg} ({self._parser._config_source_hint})" # type: ignore
msg = f'{msg} ({self._parser._config_source_hint})' # type: ignore
raise UsageError(self.format_usage() + msg)
# Type ignored because typeshed has a very complex type in the superclass.
def parse_args( # type: ignore
self,
args: Optional[Sequence[str]] = None,
namespace: Optional[argparse.Namespace] = None,
args: Sequence[str] | None = None,
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
"""Allow splitting of positional arguments."""
parsed, unrecognized = self.parse_known_args(args, namespace)
if unrecognized:
for arg in unrecognized:
if arg and arg[0] == "-":
lines = ["unrecognized arguments: %s" % (" ".join(unrecognized))]
if arg and arg[0] == '-':
lines = ['unrecognized arguments: %s' % (' '.join(unrecognized))]
for k, v in sorted(self.extra_info.items()):
lines.append(f" {k}: {v}")
self.error("\n".join(lines))
lines.append(f' {k}: {v}')
self.error('\n'.join(lines))
getattr(parsed, FILE_OR_DIR).extend(unrecognized)
return parsed
@ -452,8 +454,8 @@ class MyOptionParser(argparse.ArgumentParser):
# Backport of https://github.com/python/cpython/pull/14316 so we can
# disable long --argument abbreviations without breaking short flags.
def _parse_optional(
self, arg_string: str
) -> Optional[Tuple[Optional[argparse.Action], str, Optional[str]]]:
self, arg_string: str,
) -> tuple[argparse.Action | None, str, str | None] | None:
if not arg_string:
return None
if arg_string[0] not in self.prefix_chars:
@ -463,26 +465,26 @@ class MyOptionParser(argparse.ArgumentParser):
return action, arg_string, None
if len(arg_string) == 1:
return None
if "=" in arg_string:
option_string, explicit_arg = arg_string.split("=", 1)
if '=' in arg_string:
option_string, explicit_arg = arg_string.split('=', 1)
if option_string in self._option_string_actions:
action = self._option_string_actions[option_string]
return action, option_string, explicit_arg
if self.allow_abbrev or not arg_string.startswith("--"):
if self.allow_abbrev or not arg_string.startswith('--'):
option_tuples = self._get_option_tuples(arg_string)
if len(option_tuples) > 1:
msg = gettext(
"ambiguous option: %(option)s could match %(matches)s"
'ambiguous option: %(option)s could match %(matches)s',
)
options = ", ".join(option for _, option, _ in option_tuples)
self.error(msg % {"option": arg_string, "matches": options})
options = ', '.join(option for _, option, _ in option_tuples)
self.error(msg % {'option': arg_string, 'matches': options})
elif len(option_tuples) == 1:
(option_tuple,) = option_tuples
return option_tuple
if self._negative_number_matcher.match(arg_string):
if not self._has_negative_number_optionals:
return None
if " " in arg_string:
if ' ' in arg_string:
return None
return None, arg_string, None
@ -497,45 +499,45 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
def __init__(self, *args: Any, **kwargs: Any) -> None:
# Use more accurate terminal width.
if "width" not in kwargs:
kwargs["width"] = _pytest._io.get_terminal_width()
if 'width' not in kwargs:
kwargs['width'] = _pytest._io.get_terminal_width()
super().__init__(*args, **kwargs)
def _format_action_invocation(self, action: argparse.Action) -> str:
orgstr = super()._format_action_invocation(action)
if orgstr and orgstr[0] != "-": # only optional arguments
if orgstr and orgstr[0] != '-': # only optional arguments
return orgstr
res: Optional[str] = getattr(action, "_formatted_action_invocation", None)
res: str | None = getattr(action, '_formatted_action_invocation', None)
if res:
return res
options = orgstr.split(", ")
options = orgstr.split(', ')
if len(options) == 2 and (len(options[0]) == 2 or len(options[1]) == 2):
# a shortcut for '-h, --help' or '--abc', '-a'
action._formatted_action_invocation = orgstr # type: ignore
return orgstr
return_list = []
short_long: Dict[str, str] = {}
short_long: dict[str, str] = {}
for option in options:
if len(option) == 2 or option[2] == " ":
if len(option) == 2 or option[2] == ' ':
continue
if not option.startswith("--"):
if not option.startswith('--'):
raise ArgumentError(
'long optional argument without "--": [%s]' % (option), option
'long optional argument without "--": [%s]' % (option), option,
)
xxoption = option[2:]
shortened = xxoption.replace("-", "")
shortened = xxoption.replace('-', '')
if shortened not in short_long or len(short_long[shortened]) < len(
xxoption
xxoption,
):
short_long[shortened] = xxoption
# now short_long has been filled out to the longest with dashes
# **and** we keep the right option ordering from add_argument
for option in options:
if len(option) == 2 or option[2] == " ":
if len(option) == 2 or option[2] == ' ':
return_list.append(option)
if option[2:] == short_long.get(option.replace("-", "")):
return_list.append(option.replace(" ", "=", 1))
formatted_action_invocation = ", ".join(return_list)
if option[2:] == short_long.get(option.replace('-', '')):
return_list.append(option.replace(' ', '=', 1))
formatted_action_invocation = ', '.join(return_list)
action._formatted_action_invocation = formatted_action_invocation # type: ignore
return formatted_action_invocation

View file

@ -1,10 +1,10 @@
from __future__ import annotations
import functools
import warnings
from pathlib import Path
from typing import Any
from typing import Mapping
import warnings
import pluggy
@ -15,19 +15,19 @@ from ..deprecated import HOOK_LEGACY_PATH_ARG
# hookname: (Path, LEGACY_PATH)
imply_paths_hooks: Mapping[str, tuple[str, str]] = {
"pytest_ignore_collect": ("collection_path", "path"),
"pytest_collect_file": ("file_path", "path"),
"pytest_pycollect_makemodule": ("module_path", "path"),
"pytest_report_header": ("start_path", "startdir"),
"pytest_report_collectionfinish": ("start_path", "startdir"),
'pytest_ignore_collect': ('collection_path', 'path'),
'pytest_collect_file': ('file_path', 'path'),
'pytest_pycollect_makemodule': ('module_path', 'path'),
'pytest_report_header': ('start_path', 'startdir'),
'pytest_report_collectionfinish': ('start_path', 'startdir'),
}
def _check_path(path: Path, fspath: LEGACY_PATH) -> None:
if Path(fspath) != path:
raise ValueError(
f"Path({fspath!r}) != {path!r}\n"
"if both path and fspath are given they need to be equal"
f'Path({fspath!r}) != {path!r}\n'
'if both path and fspath are given they need to be equal',
)
@ -61,7 +61,7 @@ class PathAwareHookProxy:
if fspath_value is not None:
warnings.warn(
HOOK_LEGACY_PATH_ARG.format(
pylib_path_arg=fspath_var, pathlib_path_arg=path_var
pylib_path_arg=fspath_var, pathlib_path_arg=path_var,
),
stacklevel=2,
)

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import final

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import os
from pathlib import Path
import sys
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
@ -10,13 +12,13 @@ from typing import Tuple
from typing import Union
import iniconfig
from .exceptions import UsageError
from _pytest.outcomes import fail
from _pytest.pathlib import absolutepath
from _pytest.pathlib import commonpath
from _pytest.pathlib import safe_exists
from .exceptions import UsageError
def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
"""Parse the given generic '.ini' file using legacy IniConfig parser, returning
@ -32,52 +34,52 @@ def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
def load_config_dict_from_file(
filepath: Path,
) -> Optional[Dict[str, Union[str, List[str]]]]:
) -> dict[str, str | list[str]] | None:
"""Load pytest configuration from the given file path, if supported.
Return None if the file does not contain valid pytest configuration.
"""
# Configuration from ini files are obtained from the [pytest] section, if present.
if filepath.suffix == ".ini":
if filepath.suffix == '.ini':
iniconfig = _parse_ini_config(filepath)
if "pytest" in iniconfig:
return dict(iniconfig["pytest"].items())
if 'pytest' in iniconfig:
return dict(iniconfig['pytest'].items())
else:
# "pytest.ini" files are always the source of configuration, even if empty.
if filepath.name == "pytest.ini":
if filepath.name == 'pytest.ini':
return {}
# '.cfg' files are considered if they contain a "[tool:pytest]" section.
elif filepath.suffix == ".cfg":
elif filepath.suffix == '.cfg':
iniconfig = _parse_ini_config(filepath)
if "tool:pytest" in iniconfig.sections:
return dict(iniconfig["tool:pytest"].items())
elif "pytest" in iniconfig.sections:
if 'tool:pytest' in iniconfig.sections:
return dict(iniconfig['tool:pytest'].items())
elif 'pytest' in iniconfig.sections:
# If a setup.cfg contains a "[pytest]" section, we raise a failure to indicate users that
# plain "[pytest]" sections in setup.cfg files is no longer supported (#3086).
fail(CFG_PYTEST_SECTION.format(filename="setup.cfg"), pytrace=False)
fail(CFG_PYTEST_SECTION.format(filename='setup.cfg'), pytrace=False)
# '.toml' files are considered if they contain a [tool.pytest.ini_options] table.
elif filepath.suffix == ".toml":
elif filepath.suffix == '.toml':
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
toml_text = filepath.read_text(encoding="utf-8")
toml_text = filepath.read_text(encoding='utf-8')
try:
config = tomllib.loads(toml_text)
except tomllib.TOMLDecodeError as exc:
raise UsageError(f"{filepath}: {exc}") from exc
raise UsageError(f'{filepath}: {exc}') from exc
result = config.get("tool", {}).get("pytest", {}).get("ini_options", None)
result = config.get('tool', {}).get('pytest', {}).get('ini_options', None)
if result is not None:
# TOML supports richer data types than ini files (strings, arrays, floats, ints, etc),
# however we need to convert all scalar values to str for compatibility with the rest
# of the configuration system, which expects strings only.
def make_scalar(v: object) -> Union[str, List[str]]:
def make_scalar(v: object) -> str | list[str]:
return v if isinstance(v, list) else str(v)
return {k: make_scalar(v) for k, v in result.items()}
@ -88,27 +90,27 @@ def load_config_dict_from_file(
def locate_config(
invocation_dir: Path,
args: Iterable[Path],
) -> Tuple[Optional[Path], Optional[Path], Dict[str, Union[str, List[str]]]]:
) -> tuple[Path | None, Path | None, dict[str, str | list[str]]]:
"""Search in the list of arguments for a valid ini-file for pytest,
and return a tuple of (rootdir, inifile, cfg-dict)."""
config_names = [
"pytest.ini",
".pytest.ini",
"pyproject.toml",
"tox.ini",
"setup.cfg",
'pytest.ini',
'.pytest.ini',
'pyproject.toml',
'tox.ini',
'setup.cfg',
]
args = [x for x in args if not str(x).startswith("-")]
args = [x for x in args if not str(x).startswith('-')]
if not args:
args = [invocation_dir]
found_pyproject_toml: Optional[Path] = None
found_pyproject_toml: Path | None = None
for arg in args:
argpath = absolutepath(arg)
for base in (argpath, *argpath.parents):
for config_name in config_names:
p = base / config_name
if p.is_file():
if p.name == "pyproject.toml" and found_pyproject_toml is None:
if p.name == 'pyproject.toml' and found_pyproject_toml is None:
found_pyproject_toml = p
ini_config = load_config_dict_from_file(p)
if ini_config is not None:
@ -122,7 +124,7 @@ def get_common_ancestor(
invocation_dir: Path,
paths: Iterable[Path],
) -> Path:
common_ancestor: Optional[Path] = None
common_ancestor: Path | None = None
for path in paths:
if not path.exists():
continue
@ -144,12 +146,12 @@ def get_common_ancestor(
return common_ancestor
def get_dirs_from_args(args: Iterable[str]) -> List[Path]:
def get_dirs_from_args(args: Iterable[str]) -> list[Path]:
def is_option(x: str) -> bool:
return x.startswith("-")
return x.startswith('-')
def get_file_part_from_node_id(x: str) -> str:
return x.split("::")[0]
return x.split('::')[0]
def get_dir_from_path(path: Path) -> Path:
if path.is_dir():
@ -166,16 +168,16 @@ def get_dirs_from_args(args: Iterable[str]) -> List[Path]:
return [get_dir_from_path(path) for path in possible_paths if safe_exists(path)]
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead."
CFG_PYTEST_SECTION = '[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead.'
def determine_setup(
*,
inifile: Optional[str],
inifile: str | None,
args: Sequence[str],
rootdir_cmd_arg: Optional[str],
rootdir_cmd_arg: str | None,
invocation_dir: Path,
) -> Tuple[Path, Optional[Path], Dict[str, Union[str, List[str]]]]:
) -> tuple[Path, Path | None, dict[str, str | list[str]]]:
"""Determine the rootdir, inifile and ini configuration values from the
command line arguments.
@ -192,7 +194,7 @@ def determine_setup(
dirs = get_dirs_from_args(args)
if inifile:
inipath_ = absolutepath(inifile)
inipath: Optional[Path] = inipath_
inipath: Path | None = inipath_
inicfg = load_config_dict_from_file(inipath_) or {}
if rootdir_cmd_arg is None:
rootdir = inipath_.parent
@ -201,7 +203,7 @@ def determine_setup(
rootdir, inipath, inicfg = locate_config(invocation_dir, [ancestor])
if rootdir is None and rootdir_cmd_arg is None:
for possible_rootdir in (ancestor, *ancestor.parents):
if (possible_rootdir / "setup.py").is_file():
if (possible_rootdir / 'setup.py').is_file():
rootdir = possible_rootdir
break
else:
@ -209,7 +211,7 @@ def determine_setup(
rootdir, inipath, inicfg = locate_config(invocation_dir, dirs)
if rootdir is None:
rootdir = get_common_ancestor(
invocation_dir, [invocation_dir, ancestor]
invocation_dir, [invocation_dir, ancestor],
)
if is_fs_root(rootdir):
rootdir = ancestor
@ -217,7 +219,7 @@ def determine_setup(
rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg))
if not rootdir.is_dir():
raise UsageError(
f"Directory '{rootdir}' not found. Check your '--rootdir' option."
f"Directory '{rootdir}' not found. Check your '--rootdir' option.",
)
assert rootdir is not None
return rootdir, inipath, inicfg or {}

View file

@ -1,9 +1,12 @@
# mypy: allow-untyped-defs
"""Interactive debugging with PDB, the Python Debugger."""
from __future__ import annotations
import argparse
import functools
import sys
import types
import unittest
from typing import Any
from typing import Callable
from typing import Generator
@ -13,7 +16,6 @@ from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
import unittest
from _pytest import outcomes
from _pytest._code import ExceptionInfo
@ -32,51 +34,51 @@ if TYPE_CHECKING:
from _pytest.runner import CallInfo
def _validate_usepdb_cls(value: str) -> Tuple[str, str]:
def _validate_usepdb_cls(value: str) -> tuple[str, str]:
"""Validate syntax of --pdbcls option."""
try:
modname, classname = value.split(":")
modname, classname = value.split(':')
except ValueError as e:
raise argparse.ArgumentTypeError(
f"{value!r} is not in the format 'modname:classname'"
f"{value!r} is not in the format 'modname:classname'",
) from e
return (modname, classname)
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group = parser.getgroup('general')
group._addoption(
"--pdb",
dest="usepdb",
action="store_true",
help="Start the interactive Python debugger on errors or KeyboardInterrupt",
'--pdb',
dest='usepdb',
action='store_true',
help='Start the interactive Python debugger on errors or KeyboardInterrupt',
)
group._addoption(
"--pdbcls",
dest="usepdb_cls",
metavar="modulename:classname",
'--pdbcls',
dest='usepdb_cls',
metavar='modulename:classname',
type=_validate_usepdb_cls,
help="Specify a custom interactive Python debugger for use with --pdb."
"For example: --pdbcls=IPython.terminal.debugger:TerminalPdb",
help='Specify a custom interactive Python debugger for use with --pdb.'
'For example: --pdbcls=IPython.terminal.debugger:TerminalPdb',
)
group._addoption(
"--trace",
dest="trace",
action="store_true",
help="Immediately break when running each test",
'--trace',
dest='trace',
action='store_true',
help='Immediately break when running each test',
)
def pytest_configure(config: Config) -> None:
import pdb
if config.getvalue("trace"):
config.pluginmanager.register(PdbTrace(), "pdbtrace")
if config.getvalue("usepdb"):
config.pluginmanager.register(PdbInvoke(), "pdbinvoke")
if config.getvalue('trace'):
config.pluginmanager.register(PdbTrace(), 'pdbtrace')
if config.getvalue('usepdb'):
config.pluginmanager.register(PdbInvoke(), 'pdbinvoke')
pytestPDB._saved.append(
(pdb.set_trace, pytestPDB._pluginmanager, pytestPDB._config)
(pdb.set_trace, pytestPDB._pluginmanager, pytestPDB._config),
)
pdb.set_trace = pytestPDB.set_trace
pytestPDB._pluginmanager = config.pluginmanager
@ -97,29 +99,29 @@ def pytest_configure(config: Config) -> None:
class pytestPDB:
"""Pseudo PDB that defers to the real pdb."""
_pluginmanager: Optional[PytestPluginManager] = None
_config: Optional[Config] = None
_saved: List[
Tuple[Callable[..., None], Optional[PytestPluginManager], Optional[Config]]
_pluginmanager: PytestPluginManager | None = None
_config: Config | None = None
_saved: list[
tuple[Callable[..., None], PytestPluginManager | None, Config | None]
] = []
_recursive_debug = 0
_wrapped_pdb_cls: Optional[Tuple[Type[Any], Type[Any]]] = None
_wrapped_pdb_cls: tuple[type[Any], type[Any]] | None = None
@classmethod
def _is_capturing(cls, capman: Optional["CaptureManager"]) -> Union[str, bool]:
def _is_capturing(cls, capman: CaptureManager | None) -> str | bool:
if capman:
return capman.is_capturing()
return False
@classmethod
def _import_pdb_cls(cls, capman: Optional["CaptureManager"]):
def _import_pdb_cls(cls, capman: CaptureManager | None):
if not cls._config:
import pdb
# Happens when using pytest.set_trace outside of a test.
return pdb.Pdb
usepdb_cls = cls._config.getvalue("usepdb_cls")
usepdb_cls = cls._config.getvalue('usepdb_cls')
if cls._wrapped_pdb_cls and cls._wrapped_pdb_cls[0] == usepdb_cls:
return cls._wrapped_pdb_cls[1]
@ -132,14 +134,14 @@ class pytestPDB:
mod = sys.modules[modname]
# Handle --pdbcls=pdb:pdb.Pdb (useful e.g. with pdbpp).
parts = classname.split(".")
parts = classname.split('.')
pdb_cls = getattr(mod, parts[0])
for part in parts[1:]:
pdb_cls = getattr(pdb_cls, part)
except Exception as exc:
value = ":".join((modname, classname))
value = ':'.join((modname, classname))
raise UsageError(
f"--pdbcls: could not import {value!r}: {exc}"
f'--pdbcls: could not import {value!r}: {exc}',
) from exc
else:
import pdb
@ -151,7 +153,7 @@ class pytestPDB:
return wrapped_cls
@classmethod
def _get_pdb_wrapper_class(cls, pdb_cls, capman: Optional["CaptureManager"]):
def _get_pdb_wrapper_class(cls, pdb_cls, capman: CaptureManager | None):
import _pytest.config
# Type ignored because mypy doesn't support "dynamic"
@ -176,18 +178,18 @@ class pytestPDB:
capman = self._pytest_capman
capturing = pytestPDB._is_capturing(capman)
if capturing:
if capturing == "global":
tw.sep(">", "PDB continue (IO-capturing resumed)")
if capturing == 'global':
tw.sep('>', 'PDB continue (IO-capturing resumed)')
else:
tw.sep(
">",
"PDB continue (IO-capturing resumed for %s)"
'>',
'PDB continue (IO-capturing resumed for %s)'
% capturing,
)
assert capman is not None
capman.resume()
else:
tw.sep(">", "PDB continue")
tw.sep('>', 'PDB continue')
assert cls._pluginmanager is not None
cls._pluginmanager.hook.pytest_leave_pdb(config=cls._config, pdb=self)
self._continued = True
@ -205,7 +207,7 @@ class pytestPDB:
ret = super().do_quit(arg)
if cls._recursive_debug == 0:
outcomes.exit("Quitting debugger")
outcomes.exit('Quitting debugger')
return ret
@ -231,7 +233,7 @@ class pytestPDB:
if f is None:
# Find last non-hidden frame.
i = max(0, len(stack) - 1)
while i and stack[i][0].f_locals.get("__tracebackhide__", False):
while i and stack[i][0].f_locals.get('__tracebackhide__', False):
i -= 1
return stack, i
@ -243,9 +245,9 @@ class pytestPDB:
import _pytest.config
if cls._pluginmanager is None:
capman: Optional[CaptureManager] = None
capman: CaptureManager | None = None
else:
capman = cls._pluginmanager.getplugin("capturemanager")
capman = cls._pluginmanager.getplugin('capturemanager')
if capman:
capman.suspend(in_=True)
@ -255,20 +257,20 @@ class pytestPDB:
if cls._recursive_debug == 0:
# Handle header similar to pdb.set_trace in py37+.
header = kwargs.pop("header", None)
header = kwargs.pop('header', None)
if header is not None:
tw.sep(">", header)
tw.sep('>', header)
else:
capturing = cls._is_capturing(capman)
if capturing == "global":
tw.sep(">", f"PDB {method} (IO-capturing turned off)")
if capturing == 'global':
tw.sep('>', f'PDB {method} (IO-capturing turned off)')
elif capturing:
tw.sep(
">",
f"PDB {method} (IO-capturing turned off for {capturing})",
'>',
f'PDB {method} (IO-capturing turned off for {capturing})',
)
else:
tw.sep(">", f"PDB {method}")
tw.sep('>', f'PDB {method}')
_pdb = cls._import_pdb_cls(capman)(**kwargs)
@ -280,15 +282,15 @@ class pytestPDB:
def set_trace(cls, *args, **kwargs) -> None:
"""Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing."""
frame = sys._getframe().f_back
_pdb = cls._init_pdb("set_trace", *args, **kwargs)
_pdb = cls._init_pdb('set_trace', *args, **kwargs)
_pdb.set_trace(frame)
class PdbInvoke:
def pytest_exception_interact(
self, node: Node, call: "CallInfo[Any]", report: BaseReport
self, node: Node, call: CallInfo[Any], report: BaseReport,
) -> None:
capman = node.config.pluginmanager.getplugin("capturemanager")
capman = node.config.pluginmanager.getplugin('capturemanager')
if capman:
capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture()
@ -316,7 +318,7 @@ def wrap_pytest_function_for_tracing(pyfuncitem):
wrapper which actually enters pdb before calling the python function
itself, effectively leaving the user in the pdb prompt in the first
statement of the function."""
_pdb = pytestPDB._init_pdb("runcall")
_pdb = pytestPDB._init_pdb('runcall')
testfunction = pyfuncitem.obj
# we can't just return `partial(pdb.runcall, testfunction)` because (on
@ -333,35 +335,35 @@ def wrap_pytest_function_for_tracing(pyfuncitem):
def maybe_wrap_pytest_function_for_tracing(pyfuncitem):
"""Wrap the given pytestfunct item for tracing support if --trace was given in
the command line."""
if pyfuncitem.config.getvalue("trace"):
if pyfuncitem.config.getvalue('trace'):
wrap_pytest_function_for_tracing(pyfuncitem)
def _enter_pdb(
node: Node, excinfo: ExceptionInfo[BaseException], rep: BaseReport
node: Node, excinfo: ExceptionInfo[BaseException], rep: BaseReport,
) -> BaseReport:
# XXX we re-use the TerminalReporter's terminalwriter
# because this seems to avoid some encoding related troubles
# for not completely clear reasons.
tw = node.config.pluginmanager.getplugin("terminalreporter")._tw
tw = node.config.pluginmanager.getplugin('terminalreporter')._tw
tw.line()
showcapture = node.config.option.showcapture
for sectionname, content in (
("stdout", rep.capstdout),
("stderr", rep.capstderr),
("log", rep.caplog),
('stdout', rep.capstdout),
('stderr', rep.capstderr),
('log', rep.caplog),
):
if showcapture in (sectionname, "all") and content:
tw.sep(">", "captured " + sectionname)
if content[-1:] == "\n":
if showcapture in (sectionname, 'all') and content:
tw.sep('>', 'captured ' + sectionname)
if content[-1:] == '\n':
content = content[:-1]
tw.line(content)
tw.sep(">", "traceback")
tw.sep('>', 'traceback')
rep.toterminal(tw)
tw.sep(">", "entering PDB")
tw.sep('>', 'entering PDB')
tb = _postmortem_traceback(excinfo)
rep._pdbshown = True # type: ignore[attr-defined]
post_mortem(tb)
@ -386,8 +388,8 @@ def _postmortem_traceback(excinfo: ExceptionInfo[BaseException]) -> types.Traceb
def post_mortem(t: types.TracebackType) -> None:
p = pytestPDB._init_pdb("post_mortem")
p = pytestPDB._init_pdb('post_mortem')
p.reset()
p.interaction(None, t)
if p.quitting:
outcomes.exit("Quitting debugger")
outcomes.exit('Quitting debugger')

View file

@ -8,6 +8,7 @@ All constants defined in this module should be either instances of
:class:`PytestWarning`, or :class:`UnformattedWarning`
in case of warnings which need to format their messages.
"""
from __future__ import annotations
from warnings import warn
@ -19,50 +20,50 @@ from _pytest.warning_types import UnformattedWarning
# set of plugins which have been integrated into the core; we use this list to ignore
# them during registration to avoid conflicts
DEPRECATED_EXTERNAL_PLUGINS = {
"pytest_catchlog",
"pytest_capturelog",
"pytest_faulthandler",
'pytest_catchlog',
'pytest_capturelog',
'pytest_faulthandler',
}
# This can be* removed pytest 8, but it's harmless and common, so no rush to remove.
# * If you're in the future: "could have been".
YIELD_FIXTURE = PytestDeprecationWarning(
"@pytest.yield_fixture is deprecated.\n"
"Use @pytest.fixture instead; they are the same."
'@pytest.yield_fixture is deprecated.\n'
'Use @pytest.fixture instead; they are the same.',
)
# This deprecation is never really meant to be removed.
PRIVATE = PytestDeprecationWarning("A private pytest class or function was used.")
PRIVATE = PytestDeprecationWarning('A private pytest class or function was used.')
HOOK_LEGACY_PATH_ARG = UnformattedWarning(
PytestRemovedIn9Warning,
"The ({pylib_path_arg}: py.path.local) argument is deprecated, please use ({pathlib_path_arg}: pathlib.Path)\n"
"see https://docs.pytest.org/en/latest/deprecations.html"
"#py-path-local-arguments-for-hooks-replaced-with-pathlib-path",
'The ({pylib_path_arg}: py.path.local) argument is deprecated, please use ({pathlib_path_arg}: pathlib.Path)\n'
'see https://docs.pytest.org/en/latest/deprecations.html'
'#py-path-local-arguments-for-hooks-replaced-with-pathlib-path',
)
NODE_CTOR_FSPATH_ARG = UnformattedWarning(
PytestRemovedIn9Warning,
"The (fspath: py.path.local) argument to {node_type_name} is deprecated. "
"Please use the (path: pathlib.Path) argument instead.\n"
"See https://docs.pytest.org/en/latest/deprecations.html"
"#fspath-argument-for-node-constructors-replaced-with-pathlib-path",
'The (fspath: py.path.local) argument to {node_type_name} is deprecated. '
'Please use the (path: pathlib.Path) argument instead.\n'
'See https://docs.pytest.org/en/latest/deprecations.html'
'#fspath-argument-for-node-constructors-replaced-with-pathlib-path',
)
HOOK_LEGACY_MARKING = UnformattedWarning(
PytestDeprecationWarning,
"The hook{type} {fullname} uses old-style configuration options (marks or attributes).\n"
"Please use the pytest.hook{type}({hook_opts}) decorator instead\n"
" to configure the hooks.\n"
" See https://docs.pytest.org/en/latest/deprecations.html"
"#configuring-hook-specs-impls-using-markers",
'The hook{type} {fullname} uses old-style configuration options (marks or attributes).\n'
'Please use the pytest.hook{type}({hook_opts}) decorator instead\n'
' to configure the hooks.\n'
' See https://docs.pytest.org/en/latest/deprecations.html'
'#configuring-hook-specs-impls-using-markers',
)
MARKED_FIXTURE = PytestRemovedIn9Warning(
"Marks applied to fixtures have no effect\n"
"See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function"
'Marks applied to fixtures have no effect\n'
'See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function',
)
# You want to make some `__init__` or function "private".

View file

@ -1,15 +1,18 @@
# mypy: allow-untyped-defs
"""Discover and run doctests in modules and test files."""
from __future__ import annotations
import bdb
from contextlib import contextmanager
import functools
import inspect
import os
from pathlib import Path
import platform
import sys
import traceback
import types
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
@ -23,7 +26,6 @@ from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
import warnings
from _pytest import outcomes
from _pytest._code.code import ExceptionInfo
@ -49,11 +51,11 @@ if TYPE_CHECKING:
import doctest
from typing import Self
DOCTEST_REPORT_CHOICE_NONE = "none"
DOCTEST_REPORT_CHOICE_CDIFF = "cdiff"
DOCTEST_REPORT_CHOICE_NDIFF = "ndiff"
DOCTEST_REPORT_CHOICE_UDIFF = "udiff"
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE = "only_first_failure"
DOCTEST_REPORT_CHOICE_NONE = 'none'
DOCTEST_REPORT_CHOICE_CDIFF = 'cdiff'
DOCTEST_REPORT_CHOICE_NDIFF = 'ndiff'
DOCTEST_REPORT_CHOICE_UDIFF = 'udiff'
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE = 'only_first_failure'
DOCTEST_REPORT_CHOICES = (
DOCTEST_REPORT_CHOICE_NONE,
@ -66,56 +68,56 @@ DOCTEST_REPORT_CHOICES = (
# Lazy definition of runner class
RUNNER_CLASS = None
# Lazy definition of output checker class
CHECKER_CLASS: Optional[Type["doctest.OutputChecker"]] = None
CHECKER_CLASS: type[doctest.OutputChecker] | None = None
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"doctest_optionflags",
"Option flags for doctests",
type="args",
default=["ELLIPSIS"],
'doctest_optionflags',
'Option flags for doctests',
type='args',
default=['ELLIPSIS'],
)
parser.addini(
"doctest_encoding", "Encoding used for doctest files", default="utf-8"
'doctest_encoding', 'Encoding used for doctest files', default='utf-8',
)
group = parser.getgroup("collect")
group = parser.getgroup('collect')
group.addoption(
"--doctest-modules",
action="store_true",
'--doctest-modules',
action='store_true',
default=False,
help="Run doctests in all .py modules",
dest="doctestmodules",
help='Run doctests in all .py modules',
dest='doctestmodules',
)
group.addoption(
"--doctest-report",
'--doctest-report',
type=str.lower,
default="udiff",
help="Choose another output format for diffs on doctest failure",
default='udiff',
help='Choose another output format for diffs on doctest failure',
choices=DOCTEST_REPORT_CHOICES,
dest="doctestreport",
dest='doctestreport',
)
group.addoption(
"--doctest-glob",
action="append",
'--doctest-glob',
action='append',
default=[],
metavar="pat",
help="Doctests file matching pattern, default: test*.txt",
dest="doctestglob",
metavar='pat',
help='Doctests file matching pattern, default: test*.txt',
dest='doctestglob',
)
group.addoption(
"--doctest-ignore-import-errors",
action="store_true",
'--doctest-ignore-import-errors',
action='store_true',
default=False,
help="Ignore doctest collection errors",
dest="doctest_ignore_import_errors",
help='Ignore doctest collection errors',
dest='doctest_ignore_import_errors',
)
group.addoption(
"--doctest-continue-on-failure",
action="store_true",
'--doctest-continue-on-failure',
action='store_true',
default=False,
help="For a given doctest, continue to run after the first failure",
dest="doctest_continue_on_failure",
help='For a given doctest, continue to run after the first failure',
dest='doctest_continue_on_failure',
)
@ -128,11 +130,11 @@ def pytest_unconfigure() -> None:
def pytest_collect_file(
file_path: Path,
parent: Collector,
) -> Optional[Union["DoctestModule", "DoctestTextfile"]]:
) -> DoctestModule | DoctestTextfile | None:
config = parent.config
if file_path.suffix == ".py":
if file_path.suffix == '.py':
if config.option.doctestmodules and not any(
(_is_setup_py(file_path), _is_main_py(file_path))
(_is_setup_py(file_path), _is_main_py(file_path)),
):
return DoctestModule.from_parent(parent, path=file_path)
elif _is_doctest(config, file_path, parent):
@ -141,26 +143,26 @@ def pytest_collect_file(
def _is_setup_py(path: Path) -> bool:
if path.name != "setup.py":
if path.name != 'setup.py':
return False
contents = path.read_bytes()
return b"setuptools" in contents or b"distutils" in contents
return b'setuptools' in contents or b'distutils' in contents
def _is_doctest(config: Config, path: Path, parent: Collector) -> bool:
if path.suffix in (".txt", ".rst") and parent.session.isinitpath(path):
if path.suffix in ('.txt', '.rst') and parent.session.isinitpath(path):
return True
globs = config.getoption("doctestglob") or ["test*.txt"]
globs = config.getoption('doctestglob') or ['test*.txt']
return any(fnmatch_ex(glob, path) for glob in globs)
def _is_main_py(path: Path) -> bool:
return path.name == "__main__.py"
return path.name == '__main__.py'
class ReprFailDoctest(TerminalRepr):
def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]]
self, reprlocation_lines: Sequence[tuple[ReprFileLocation, Sequence[str]]],
) -> None:
self.reprlocation_lines = reprlocation_lines
@ -172,12 +174,12 @@ class ReprFailDoctest(TerminalRepr):
class MultipleDoctestFailures(Exception):
def __init__(self, failures: Sequence["doctest.DocTestFailure"]) -> None:
def __init__(self, failures: Sequence[doctest.DocTestFailure]) -> None:
super().__init__()
self.failures = failures
def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def _init_runner_class() -> type[doctest.DocTestRunner]:
import doctest
class PytestDoctestRunner(doctest.DebugRunner):
@ -189,8 +191,8 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def __init__(
self,
checker: Optional["doctest.OutputChecker"] = None,
verbose: Optional[bool] = None,
checker: doctest.OutputChecker | None = None,
verbose: bool | None = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> None:
@ -200,8 +202,8 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def report_failure(
self,
out,
test: "doctest.DocTest",
example: "doctest.Example",
test: doctest.DocTest,
example: doctest.Example,
got: str,
) -> None:
failure = doctest.DocTestFailure(test, example, got)
@ -213,14 +215,14 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def report_unexpected_exception(
self,
out,
test: "doctest.DocTest",
example: "doctest.Example",
exc_info: Tuple[Type[BaseException], BaseException, types.TracebackType],
test: doctest.DocTest,
example: doctest.Example,
exc_info: tuple[type[BaseException], BaseException, types.TracebackType],
) -> None:
if isinstance(exc_info[1], OutcomeException):
raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit):
outcomes.exit("Quitting debugger")
outcomes.exit('Quitting debugger')
failure = doctest.UnexpectedException(test, example, exc_info)
if self.continue_on_failure:
out.append(failure)
@ -231,11 +233,11 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def _get_runner(
checker: Optional["doctest.OutputChecker"] = None,
verbose: Optional[bool] = None,
checker: doctest.OutputChecker | None = None,
verbose: bool | None = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> "doctest.DocTestRunner":
) -> doctest.DocTestRunner:
# We need this in order to do a lazy import on doctest
global RUNNER_CLASS
if RUNNER_CLASS is None:
@ -254,9 +256,9 @@ class DoctestItem(Item):
def __init__(
self,
name: str,
parent: "Union[DoctestTextfile, DoctestModule]",
runner: "doctest.DocTestRunner",
dtest: "doctest.DocTest",
parent: Union[DoctestTextfile, DoctestModule],
runner: doctest.DocTestRunner,
dtest: doctest.DocTest,
) -> None:
super().__init__(name, parent)
self.runner = runner
@ -273,31 +275,31 @@ class DoctestItem(Item):
@classmethod
def from_parent( # type: ignore[override]
cls,
parent: "Union[DoctestTextfile, DoctestModule]",
parent: Union[DoctestTextfile, DoctestModule],
*,
name: str,
runner: "doctest.DocTestRunner",
dtest: "doctest.DocTest",
) -> "Self":
runner: doctest.DocTestRunner,
dtest: doctest.DocTest,
) -> Self:
# incompatible signature due to imposed limits on subclass
"""The public named constructor."""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def _initrequest(self) -> None:
self.funcargs: Dict[str, object] = {}
self.funcargs: dict[str, object] = {}
self._request = TopRequest(self, _ispytest=True) # type: ignore[arg-type]
def setup(self) -> None:
self._request._fillfixtures()
globs = dict(getfixture=self._request.getfixturevalue)
for name, value in self._request.getfixturevalue("doctest_namespace").items():
for name, value in self._request.getfixturevalue('doctest_namespace').items():
globs[name] = value
self.dtest.globs.update(globs)
def runtest(self) -> None:
_check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin()
failures: List["doctest.DocTestFailure"] = []
failures: list[doctest.DocTestFailure] = []
# Type ignored because we change the type of `out` from what
# doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type]
@ -306,9 +308,9 @@ class DoctestItem(Item):
def _disable_output_capturing_for_darwin(self) -> None:
"""Disable output capturing. Otherwise, stdout is lost to doctest (#985)."""
if platform.system() != "Darwin":
if platform.system() != 'Darwin':
return
capman = self.config.pluginmanager.getplugin("capturemanager")
capman = self.config.pluginmanager.getplugin('capturemanager')
if capman:
capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture()
@ -319,14 +321,14 @@ class DoctestItem(Item):
def repr_failure( # type: ignore[override]
self,
excinfo: ExceptionInfo[BaseException],
) -> Union[str, TerminalRepr]:
) -> str | TerminalRepr:
import doctest
failures: Optional[
Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]]
] = None
failures: None | (
Sequence[doctest.DocTestFailure | doctest.UnexpectedException]
) = None
if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException),
):
failures = [excinfo.value]
elif isinstance(excinfo.value, MultipleDoctestFailures):
@ -348,43 +350,43 @@ class DoctestItem(Item):
# TODO: ReprFileLocation doesn't expect a None lineno.
reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type]
checker = _get_checker()
report_choice = _get_report_choice(self.config.getoption("doctestreport"))
report_choice = _get_report_choice(self.config.getoption('doctestreport'))
if lineno is not None:
assert failure.test.docstring is not None
lines = failure.test.docstring.splitlines(False)
# add line numbers to the left of the error message
assert test.lineno is not None
lines = [
"%03d %s" % (i + test.lineno + 1, x) for (i, x) in enumerate(lines)
'%03d %s' % (i + test.lineno + 1, x) for (i, x) in enumerate(lines)
]
# trim docstring error lines to 10
lines = lines[max(example.lineno - 9, 0) : example.lineno + 1]
lines = lines[max(example.lineno - 9, 0): example.lineno + 1]
else:
lines = [
"EXAMPLE LOCATION UNKNOWN, not showing all tests of that example"
'EXAMPLE LOCATION UNKNOWN, not showing all tests of that example',
]
indent = ">>>"
indent = '>>>'
for line in example.source.splitlines():
lines.append(f"??? {indent} {line}")
indent = "..."
lines.append(f'??? {indent} {line}')
indent = '...'
if isinstance(failure, doctest.DocTestFailure):
lines += checker.output_difference(
example, failure.got, report_choice
).split("\n")
example, failure.got, report_choice,
).split('\n')
else:
inner_excinfo = ExceptionInfo.from_exc_info(failure.exc_info)
lines += ["UNEXPECTED EXCEPTION: %s" % repr(inner_excinfo.value)]
lines += ['UNEXPECTED EXCEPTION: %s' % repr(inner_excinfo.value)]
lines += [
x.strip("\n") for x in traceback.format_exception(*failure.exc_info)
x.strip('\n') for x in traceback.format_exception(*failure.exc_info)
]
reprlocation_lines.append((reprlocation, lines))
return ReprFailDoctest(reprlocation_lines)
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]:
return self.path, self.dtest.lineno, "[doctest] %s" % self.name
def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
return self.path, self.dtest.lineno, '[doctest] %s' % self.name
def _get_flag_lookup() -> Dict[str, int]:
def _get_flag_lookup() -> dict[str, int]:
import doctest
return dict(
@ -401,7 +403,7 @@ def _get_flag_lookup() -> Dict[str, int]:
def get_optionflags(config: Config) -> int:
optionflags_str = config.getini("doctest_optionflags")
optionflags_str = config.getini('doctest_optionflags')
flag_lookup_table = _get_flag_lookup()
flag_acc = 0
for flag in optionflags_str:
@ -410,11 +412,11 @@ def get_optionflags(config: Config) -> int:
def _get_continue_on_failure(config: Config) -> bool:
continue_on_failure: bool = config.getvalue("doctest_continue_on_failure")
continue_on_failure: bool = config.getvalue('doctest_continue_on_failure')
if continue_on_failure:
# We need to turn off this if we use pdb since we should stop at
# the first failure.
if config.getvalue("usepdb"):
if config.getvalue('usepdb'):
continue_on_failure = False
return continue_on_failure
@ -427,11 +429,11 @@ class DoctestTextfile(Module):
# Inspired by doctest.testfile; ideally we would use it directly,
# but it doesn't support passing a custom checker.
encoding = self.config.getini("doctest_encoding")
encoding = self.config.getini('doctest_encoding')
text = self.path.read_text(encoding)
filename = str(self.path)
name = self.path.name
globs = {"__name__": "__main__"}
globs = {'__name__': '__main__'}
optionflags = get_optionflags(self.config)
@ -446,25 +448,25 @@ class DoctestTextfile(Module):
test = parser.get_doctest(text, globs, name, filename, 0)
if test.examples:
yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test
self, name=test.name, runner=runner, dtest=test,
)
def _check_all_skipped(test: "doctest.DocTest") -> None:
def _check_all_skipped(test: doctest.DocTest) -> None:
"""Raise pytest.skip() if all examples in the given DocTest have the SKIP
option set."""
import doctest
all_skipped = all(x.options.get(doctest.SKIP, False) for x in test.examples)
if all_skipped:
skip("all tests skipped by +SKIP option")
skip('all tests skipped by +SKIP option')
def _is_mocked(obj: object) -> bool:
"""Return if an object is possibly a mock object by checking the
existence of a highly improbable attribute."""
return (
safe_getattr(obj, "pytest_mock_example_attribute_that_shouldnt_exist", None)
safe_getattr(obj, 'pytest_mock_example_attribute_that_shouldnt_exist', None)
is not None
)
@ -476,7 +478,7 @@ def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
real_unwrap = inspect.unwrap
def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None
func: Callable[..., Any], *, stop: Callable[[Any], Any] | None = None,
) -> Any:
try:
if stop is None or stop is _is_mocked:
@ -485,9 +487,9 @@ def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e:
warnings.warn(
f"Got {e!r} when unwrapping {func!r}. This is usually caused "
f'Got {e!r} when unwrapping {func!r}. This is usually caused '
"by a violation of Python's object protocol; see e.g. "
"https://github.com/pytest-dev/pytest/issues/5080",
'https://github.com/pytest-dev/pytest/issues/5080',
PytestWarning,
)
raise
@ -518,9 +520,9 @@ class DoctestModule(Module):
line number is returned. This will be reported upstream. #8796
"""
if isinstance(obj, property):
obj = getattr(obj, "fget", obj)
obj = getattr(obj, 'fget', obj)
if hasattr(obj, "__wrapped__"):
if hasattr(obj, '__wrapped__'):
# Get the main obj in case of it being wrapped
obj = inspect.unwrap(obj)
@ -531,14 +533,14 @@ class DoctestModule(Module):
)
def _find(
self, tests, obj, name, module, source_lines, globs, seen
self, tests, obj, name, module, source_lines, globs, seen,
) -> None:
if _is_mocked(obj):
return
with _patch_unwrap_mock_aware():
# Type ignored because this is a private function.
super()._find( # type:ignore[misc]
tests, obj, name, module, source_lines, globs, seen
tests, obj, name, module, source_lines, globs, seen,
)
if sys.version_info < (3, 13):
@ -561,8 +563,8 @@ class DoctestModule(Module):
try:
module = self.obj
except Collector.CollectError:
if self.config.getvalue("doctest_ignore_import_errors"):
skip("unable to import module %r" % self.path)
if self.config.getvalue('doctest_ignore_import_errors'):
skip('unable to import module %r' % self.path)
else:
raise
@ -583,11 +585,11 @@ class DoctestModule(Module):
for test in finder.find(module, module.__name__):
if test.examples: # skip empty doctests
yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test
self, name=test.name, runner=runner, dtest=test,
)
def _init_checker_class() -> Type["doctest.OutputChecker"]:
def _init_checker_class() -> type[doctest.OutputChecker]:
import doctest
import re
@ -633,7 +635,7 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return False
def remove_prefixes(regex: Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt)
return re.sub(regex, r'\1\2', txt)
if allow_unicode:
want = remove_prefixes(self._unicode_literal_re, want)
@ -655,10 +657,10 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return got
offset = 0
for w, g in zip(wants, gots):
fraction: Optional[str] = w.group("fraction")
exponent: Optional[str] = w.group("exponent1")
fraction: str | None = w.group('fraction')
exponent: str | None = w.group('exponent1')
if exponent is None:
exponent = w.group("exponent2")
exponent = w.group('exponent2')
precision = 0 if fraction is None else len(fraction)
if exponent is not None:
precision -= int(exponent)
@ -667,7 +669,7 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
# got with the text we want, so that it will match when we
# check the string literally.
got = (
got[: g.start() + offset] + w.group() + got[g.end() + offset :]
got[: g.start() + offset] + w.group() + got[g.end() + offset:]
)
offset += w.end() - w.start() - (g.end() - g.start())
return got
@ -675,7 +677,7 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return LiteralsOutputChecker
def _get_checker() -> "doctest.OutputChecker":
def _get_checker() -> doctest.OutputChecker:
"""Return a doctest.OutputChecker subclass that supports some
additional options:
@ -699,21 +701,21 @@ def _get_allow_unicode_flag() -> int:
"""Register and return the ALLOW_UNICODE flag."""
import doctest
return doctest.register_optionflag("ALLOW_UNICODE")
return doctest.register_optionflag('ALLOW_UNICODE')
def _get_allow_bytes_flag() -> int:
"""Register and return the ALLOW_BYTES flag."""
import doctest
return doctest.register_optionflag("ALLOW_BYTES")
return doctest.register_optionflag('ALLOW_BYTES')
def _get_number_flag() -> int:
"""Register and return the NUMBER flag."""
import doctest
return doctest.register_optionflag("NUMBER")
return doctest.register_optionflag('NUMBER')
def _get_report_choice(key: str) -> int:
@ -733,8 +735,8 @@ def _get_report_choice(key: str) -> int:
}[key]
@fixture(scope="session")
def doctest_namespace() -> Dict[str, Any]:
@fixture(scope='session')
def doctest_namespace() -> dict[str, Any]:
"""Fixture that returns a :py:class:`dict` that will be injected into the
namespace of doctests.

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import os
import sys
from typing import Generator
import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
from _pytest.stash import StashKey
import pytest
fault_handler_original_stderr_fd_key = StashKey[int]()
@ -15,10 +17,10 @@ fault_handler_stderr_fd_key = StashKey[int]()
def pytest_addoption(parser: Parser) -> None:
help = (
"Dump the traceback of all threads if a test takes "
"more than TIMEOUT seconds to finish"
'Dump the traceback of all threads if a test takes '
'more than TIMEOUT seconds to finish'
)
parser.addini("faulthandler_timeout", help, default=0.0)
parser.addini('faulthandler_timeout', help, default=0.0)
def pytest_configure(config: Config) -> None:
@ -66,7 +68,7 @@ def get_stderr_fileno() -> int:
def get_timeout_config_value(config: Config) -> float:
return float(config.getini("faulthandler_timeout") or 0.0)
return float(config.getini('faulthandler_timeout') or 0.0)
@pytest.hookimpl(wrapper=True, trylast=True)

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,6 @@
"""Provides a function to report all internal modules for using freezing
tools."""
from __future__ import annotations
import types
from typing import Iterator
@ -7,7 +8,7 @@ from typing import List
from typing import Union
def freeze_includes() -> List[str]:
def freeze_includes() -> list[str]:
"""Return a list of module names used by pytest that should be
included by cx_freeze."""
import _pytest
@ -17,8 +18,8 @@ def freeze_includes() -> List[str]:
def _iter_all_modules(
package: Union[str, types.ModuleType],
prefix: str = "",
package: str | types.ModuleType,
prefix: str = '',
) -> Iterator[str]:
"""Iterate over the names of all modules that can be found in the given
package, recursively.
@ -36,10 +37,10 @@ def _iter_all_modules(
# Type ignored because typeshed doesn't define ModuleType.__path__
# (only defined on packages).
package_path = package.__path__ # type: ignore[attr-defined]
path, prefix = package_path[0], package.__name__ + "."
path, prefix = package_path[0], package.__name__ + '.'
for _, name, is_package in pkgutil.iter_modules([path]):
if is_package:
for m in _iter_all_modules(os.path.join(path, name), prefix=name + "."):
for m in _iter_all_modules(os.path.join(path, name), prefix=name + '.'):
yield prefix + m
else:
yield prefix + name

View file

@ -1,19 +1,21 @@
# mypy: allow-untyped-defs
"""Version info, help messages, tracing configuration."""
from argparse import Action
from __future__ import annotations
import os
import sys
from argparse import Action
from typing import Generator
from typing import List
from typing import Optional
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import PrintHelp
from _pytest.config.argparsing import Parser
from _pytest.terminal import TerminalReporter
import pytest
class HelpAction(Action):
@ -40,63 +42,63 @@ class HelpAction(Action):
setattr(namespace, self.dest, self.const)
# We should only skip the rest of the parsing after preparse is done.
if getattr(parser._parser, "after_preparse", False):
if getattr(parser._parser, 'after_preparse', False):
raise PrintHelp
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group = parser.getgroup('debugconfig')
group.addoption(
"--version",
"-V",
action="count",
'--version',
'-V',
action='count',
default=0,
dest="version",
help="Display pytest version and information about plugins. "
"When given twice, also display information about plugins.",
dest='version',
help='Display pytest version and information about plugins. '
'When given twice, also display information about plugins.',
)
group._addoption(
"-h",
"--help",
'-h',
'--help',
action=HelpAction,
dest="help",
help="Show help message and configuration info",
dest='help',
help='Show help message and configuration info',
)
group._addoption(
"-p",
action="append",
dest="plugins",
'-p',
action='append',
dest='plugins',
default=[],
metavar="name",
help="Early-load given plugin module name or entry point (multi-allowed). "
"To avoid loading of plugins, use the `no:` prefix, e.g. "
"`no:doctest`.",
metavar='name',
help='Early-load given plugin module name or entry point (multi-allowed). '
'To avoid loading of plugins, use the `no:` prefix, e.g. '
'`no:doctest`.',
)
group.addoption(
"--traceconfig",
"--trace-config",
action="store_true",
'--traceconfig',
'--trace-config',
action='store_true',
default=False,
help="Trace considerations of conftest.py files",
help='Trace considerations of conftest.py files',
)
group.addoption(
"--debug",
action="store",
nargs="?",
const="pytestdebug.log",
dest="debug",
metavar="DEBUG_FILE_NAME",
help="Store internal tracing debug information in this log file. "
'--debug',
action='store',
nargs='?',
const='pytestdebug.log',
dest='debug',
metavar='DEBUG_FILE_NAME',
help='Store internal tracing debug information in this log file. '
"This file is opened with 'w' and truncated as a result, care advised. "
"Default: pytestdebug.log.",
'Default: pytestdebug.log.',
)
group._addoption(
"-o",
"--override-ini",
dest="override_ini",
action="append",
'-o',
'--override-ini',
dest='override_ini',
action='append',
help='Override ini option with "option=value" style, '
"e.g. `-o xfail_strict=True -o cache_dir=cache`.",
'e.g. `-o xfail_strict=True -o cache_dir=cache`.',
)
@ -107,24 +109,24 @@ def pytest_cmdline_parse() -> Generator[None, Config, Config]:
if config.option.debug:
# --debug | --debug <file.log> was provided.
path = config.option.debug
debugfile = open(path, "w", encoding="utf-8")
debugfile = open(path, 'w', encoding='utf-8')
debugfile.write(
"versions pytest-{}, "
"python-{}\ninvocation_dir={}\ncwd={}\nargs={}\n\n".format(
'versions pytest-{}, '
'python-{}\ninvocation_dir={}\ncwd={}\nargs={}\n\n'.format(
pytest.__version__,
".".join(map(str, sys.version_info)),
'.'.join(map(str, sys.version_info)),
config.invocation_params.dir,
os.getcwd(),
config.invocation_params.args,
)
),
)
config.trace.root.setwriter(debugfile.write)
undo_tracing = config.pluginmanager.enable_tracing()
sys.stderr.write("writing pytest debug information to %s\n" % path)
sys.stderr.write('writing pytest debug information to %s\n' % path)
def unset_tracing() -> None:
debugfile.close()
sys.stderr.write("wrote pytest debug information to %s\n" % debugfile.name)
sys.stderr.write('wrote pytest debug information to %s\n' % debugfile.name)
config.trace.root.setwriter(None)
undo_tracing()
@ -136,17 +138,17 @@ def pytest_cmdline_parse() -> Generator[None, Config, Config]:
def showversion(config: Config) -> None:
if config.option.version > 1:
sys.stdout.write(
f"This is pytest version {pytest.__version__}, imported from {pytest.__file__}\n"
f'This is pytest version {pytest.__version__}, imported from {pytest.__file__}\n',
)
plugininfo = getpluginversioninfo(config)
if plugininfo:
for line in plugininfo:
sys.stdout.write(line + "\n")
sys.stdout.write(line + '\n')
else:
sys.stdout.write(f"pytest {pytest.__version__}\n")
sys.stdout.write(f'pytest {pytest.__version__}\n')
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.version > 0:
showversion(config)
return 0
@ -161,30 +163,30 @@ def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def showhelp(config: Config) -> None:
import textwrap
reporter: Optional[TerminalReporter] = config.pluginmanager.get_plugin(
"terminalreporter"
reporter: TerminalReporter | None = config.pluginmanager.get_plugin(
'terminalreporter',
)
assert reporter is not None
tw = reporter._tw
tw.write(config._parser.optparser.format_help())
tw.line()
tw.line(
"[pytest] ini-options in the first "
"pytest.ini|tox.ini|setup.cfg|pyproject.toml file found:"
'[pytest] ini-options in the first '
'pytest.ini|tox.ini|setup.cfg|pyproject.toml file found:',
)
tw.line()
columns = tw.fullwidth # costly call
indent_len = 24 # based on argparse's max_help_position=24
indent = " " * indent_len
indent = ' ' * indent_len
for name in config._parser._ininames:
help, type, default = config._parser._inidict[name]
if type is None:
type = "string"
type = 'string'
if help is None:
raise TypeError(f"help argument cannot be None for {name}")
spec = f"{name} ({type}):"
tw.write(" %s" % spec)
raise TypeError(f'help argument cannot be None for {name}')
spec = f'{name} ({type}):'
tw.write(' %s' % spec)
spec_len = len(spec)
if spec_len > (indent_len - 3):
# Display help starting at a new line.
@ -201,7 +203,7 @@ def showhelp(config: Config) -> None:
tw.line(line)
else:
# Display help starting after the spec, following lines indented.
tw.write(" " * (indent_len - spec_len - 2))
tw.write(' ' * (indent_len - spec_len - 2))
wrapped = textwrap.wrap(help, columns - indent_len, break_on_hyphens=False)
if wrapped:
@ -210,62 +212,62 @@ def showhelp(config: Config) -> None:
tw.line(indent + line)
tw.line()
tw.line("Environment variables:")
tw.line('Environment variables:')
vars = [
("PYTEST_ADDOPTS", "Extra command line options"),
("PYTEST_PLUGINS", "Comma-separated plugins to load during startup"),
("PYTEST_DISABLE_PLUGIN_AUTOLOAD", "Set to disable plugin auto-loading"),
("PYTEST_DEBUG", "Set to enable debug tracing of pytest's internals"),
('PYTEST_ADDOPTS', 'Extra command line options'),
('PYTEST_PLUGINS', 'Comma-separated plugins to load during startup'),
('PYTEST_DISABLE_PLUGIN_AUTOLOAD', 'Set to disable plugin auto-loading'),
('PYTEST_DEBUG', "Set to enable debug tracing of pytest's internals"),
]
for name, help in vars:
tw.line(f" {name:<24} {help}")
tw.line(f' {name:<24} {help}')
tw.line()
tw.line()
tw.line("to see available markers type: pytest --markers")
tw.line("to see available fixtures type: pytest --fixtures")
tw.line('to see available markers type: pytest --markers')
tw.line('to see available fixtures type: pytest --fixtures')
tw.line(
"(shown according to specified file_or_dir or current dir "
'(shown according to specified file_or_dir or current dir '
"if not specified; fixtures with leading '_' are only shown "
"with the '-v' option"
"with the '-v' option",
)
for warningreport in reporter.stats.get("warnings", []):
tw.line("warning : " + warningreport.message, red=True)
for warningreport in reporter.stats.get('warnings', []):
tw.line('warning : ' + warningreport.message, red=True)
return
conftest_options = [("pytest_plugins", "list of plugin names to load")]
conftest_options = [('pytest_plugins', 'list of plugin names to load')]
def getpluginversioninfo(config: Config) -> List[str]:
def getpluginversioninfo(config: Config) -> list[str]:
lines = []
plugininfo = config.pluginmanager.list_plugin_distinfo()
if plugininfo:
lines.append("setuptools registered plugins:")
lines.append('setuptools registered plugins:')
for plugin, dist in plugininfo:
loc = getattr(plugin, "__file__", repr(plugin))
content = f"{dist.project_name}-{dist.version} at {loc}"
lines.append(" " + content)
loc = getattr(plugin, '__file__', repr(plugin))
content = f'{dist.project_name}-{dist.version} at {loc}'
lines.append(' ' + content)
return lines
def pytest_report_header(config: Config) -> List[str]:
def pytest_report_header(config: Config) -> list[str]:
lines = []
if config.option.debug or config.option.traceconfig:
lines.append(f"using: pytest-{pytest.__version__}")
lines.append(f'using: pytest-{pytest.__version__}')
verinfo = getpluginversioninfo(config)
if verinfo:
lines.extend(verinfo)
if config.option.traceconfig:
lines.append("active plugins:")
lines.append('active plugins:')
items = config.pluginmanager.list_name_plugin()
for name, plugin in items:
if hasattr(plugin, "__file__"):
if hasattr(plugin, '__file__'):
r = plugin.__file__
else:
r = repr(plugin)
lines.append(f" {name:<20}: {r}")
lines.append(f' {name:<20}: {r}')
return lines

View file

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
"""Hook specifications for pytest plugins which are invoked by pytest itself
and by builtin plugins."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from typing import Dict
@ -45,7 +47,7 @@ if TYPE_CHECKING:
from _pytest.terminal import TestShortLogReport
hookspec = HookspecMarker("pytest")
hookspec = HookspecMarker('pytest')
# -------------------------------------------------------------------------
# Initialization hooks called for every plugin
@ -53,7 +55,7 @@ hookspec = HookspecMarker("pytest")
@hookspec(historic=True)
def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None:
def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
"""Called at plugin registration time to allow adding new hooks via a call to
:func:`pluginmanager.add_hookspecs(module_or_class, prefix) <pytest.PytestPluginManager.add_hookspecs>`.
@ -72,9 +74,9 @@ def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None:
@hookspec(historic=True)
def pytest_plugin_registered(
plugin: "_PluggyPlugin",
plugin: _PluggyPlugin,
plugin_name: str,
manager: "PytestPluginManager",
manager: PytestPluginManager,
) -> None:
"""A new pytest plugin got registered.
@ -96,7 +98,7 @@ def pytest_plugin_registered(
@hookspec(historic=True)
def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") -> None:
def pytest_addoption(parser: Parser, pluginmanager: PytestPluginManager) -> None:
"""Register argparse-style options and ini-style config values,
called once at the beginning of a test run.
@ -137,7 +139,7 @@ def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") ->
@hookspec(historic=True)
def pytest_configure(config: "Config") -> None:
def pytest_configure(config: Config) -> None:
"""Allow plugins and conftest files to perform initial configuration.
.. note::
@ -162,8 +164,8 @@ def pytest_configure(config: "Config") -> None:
@hookspec(firstresult=True)
def pytest_cmdline_parse(
pluginmanager: "PytestPluginManager", args: List[str]
) -> Optional["Config"]:
pluginmanager: PytestPluginManager, args: list[str],
) -> Config | None:
"""Return an initialized :class:`~pytest.Config`, parsing the specified args.
Stops at first non-None result, see :ref:`firstresult`.
@ -185,7 +187,7 @@ def pytest_cmdline_parse(
def pytest_load_initial_conftests(
early_config: "Config", parser: "Parser", args: List[str]
early_config: Config, parser: Parser, args: list[str],
) -> None:
"""Called to implement the loading of :ref:`initial conftest files
<pluginorder>` ahead of command line option parsing.
@ -202,7 +204,7 @@ def pytest_load_initial_conftests(
@hookspec(firstresult=True)
def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]:
def pytest_cmdline_main(config: Config) -> ExitCode | int | None:
"""Called for performing the main command line action.
The default implementation will invoke the configure hooks and
@ -226,7 +228,7 @@ def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]:
@hookspec(firstresult=True)
def pytest_collection(session: "Session") -> Optional[object]:
def pytest_collection(session: Session) -> object | None:
"""Perform the collection phase for the given session.
Stops at first non-None result, see :ref:`firstresult`.
@ -268,7 +270,7 @@ def pytest_collection(session: "Session") -> Optional[object]:
def pytest_collection_modifyitems(
session: "Session", config: "Config", items: List["Item"]
session: Session, config: Config, items: list[Item],
) -> None:
"""Called after collection has been performed. May filter or re-order
the items in-place.
@ -284,7 +286,7 @@ def pytest_collection_modifyitems(
"""
def pytest_collection_finish(session: "Session") -> None:
def pytest_collection_finish(session: Session) -> None:
"""Called after collection has been performed and modified.
:param session: The pytest session object.
@ -298,8 +300,8 @@ def pytest_collection_finish(session: "Session") -> None:
@hookspec(firstresult=True)
def pytest_ignore_collect(
collection_path: Path, path: "LEGACY_PATH", config: "Config"
) -> Optional[bool]:
collection_path: Path, path: LEGACY_PATH, config: Config,
) -> bool | None:
"""Return True to prevent considering this path for collection.
This hook is consulted for all files and directories prior to calling
@ -327,7 +329,7 @@ def pytest_ignore_collect(
@hookspec(firstresult=True)
def pytest_collect_directory(path: Path, parent: "Collector") -> "Optional[Collector]":
def pytest_collect_directory(path: Path, parent: Collector) -> Optional[Collector]:
"""Create a :class:`~pytest.Collector` for the given directory, or None if
not relevant.
@ -356,8 +358,8 @@ def pytest_collect_directory(path: Path, parent: "Collector") -> "Optional[Colle
def pytest_collect_file(
file_path: Path, path: "LEGACY_PATH", parent: "Collector"
) -> "Optional[Collector]":
file_path: Path, path: LEGACY_PATH, parent: Collector,
) -> Optional[Collector]:
"""Create a :class:`~pytest.Collector` for the given path, or None if not relevant.
For best results, the returned collector should be a subclass of
@ -384,7 +386,7 @@ def pytest_collect_file(
# logging hooks for collection
def pytest_collectstart(collector: "Collector") -> None:
def pytest_collectstart(collector: Collector) -> None:
"""Collector starts collecting.
:param collector:
@ -399,7 +401,7 @@ def pytest_collectstart(collector: "Collector") -> None:
"""
def pytest_itemcollected(item: "Item") -> None:
def pytest_itemcollected(item: Item) -> None:
"""We just collected a test item.
:param item:
@ -413,7 +415,7 @@ def pytest_itemcollected(item: "Item") -> None:
"""
def pytest_collectreport(report: "CollectReport") -> None:
def pytest_collectreport(report: CollectReport) -> None:
"""Collector finished collecting.
:param report:
@ -428,7 +430,7 @@ def pytest_collectreport(report: "CollectReport") -> None:
"""
def pytest_deselected(items: Sequence["Item"]) -> None:
def pytest_deselected(items: Sequence[Item]) -> None:
"""Called for deselected test items, e.g. by keyword.
May be called multiple times.
@ -444,7 +446,7 @@ def pytest_deselected(items: Sequence["Item"]) -> None:
@hookspec(firstresult=True)
def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectReport]":
def pytest_make_collect_report(collector: Collector) -> Optional[CollectReport]:
"""Perform :func:`collector.collect() <pytest.Collector.collect>` and return
a :class:`~pytest.CollectReport`.
@ -469,8 +471,8 @@ def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectRepor
@hookspec(firstresult=True)
def pytest_pycollect_makemodule(
module_path: Path, path: "LEGACY_PATH", parent
) -> Optional["Module"]:
module_path: Path, path: LEGACY_PATH, parent,
) -> Module | None:
"""Return a :class:`pytest.Module` collector or None for the given path.
This hook will be called for each matching test module path.
@ -499,8 +501,8 @@ def pytest_pycollect_makemodule(
@hookspec(firstresult=True)
def pytest_pycollect_makeitem(
collector: Union["Module", "Class"], name: str, obj: object
) -> Union[None, "Item", "Collector", List[Union["Item", "Collector"]]]:
collector: Module | Class, name: str, obj: object,
) -> None | Item | Collector | list[Item | Collector]:
"""Return a custom item/collector for a Python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult`.
@ -524,7 +526,7 @@ def pytest_pycollect_makeitem(
@hookspec(firstresult=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
"""Call underlying test function.
Stops at first non-None result, see :ref:`firstresult`.
@ -541,7 +543,7 @@ def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
"""
def pytest_generate_tests(metafunc: "Metafunc") -> None:
def pytest_generate_tests(metafunc: Metafunc) -> None:
"""Generate (multiple) parametrized calls to a test function.
:param metafunc:
@ -558,8 +560,8 @@ def pytest_generate_tests(metafunc: "Metafunc") -> None:
@hookspec(firstresult=True)
def pytest_make_parametrize_id(
config: "Config", val: object, argname: str
) -> Optional[str]:
config: Config, val: object, argname: str,
) -> str | None:
"""Return a user-friendly string representation of the given ``val``
that will be used by @pytest.mark.parametrize calls, or None if the hook
doesn't know about ``val``.
@ -585,7 +587,7 @@ def pytest_make_parametrize_id(
@hookspec(firstresult=True)
def pytest_runtestloop(session: "Session") -> Optional[object]:
def pytest_runtestloop(session: Session) -> object | None:
"""Perform the main runtest loop (after collection finished).
The default hook implementation performs the runtest protocol for all items
@ -612,8 +614,8 @@ def pytest_runtestloop(session: "Session") -> Optional[object]:
@hookspec(firstresult=True)
def pytest_runtest_protocol(
item: "Item", nextitem: "Optional[Item]"
) -> Optional[object]:
item: Item, nextitem: Optional[Item],
) -> object | None:
"""Perform the runtest protocol for a single test item.
The default runtest protocol is this (see individual hooks for full details):
@ -654,7 +656,7 @@ def pytest_runtest_protocol(
def pytest_runtest_logstart(
nodeid: str, location: Tuple[str, Optional[int], str]
nodeid: str, location: tuple[str, int | None, str],
) -> None:
"""Called at the start of running the runtest protocol for a single item.
@ -674,7 +676,7 @@ def pytest_runtest_logstart(
def pytest_runtest_logfinish(
nodeid: str, location: Tuple[str, Optional[int], str]
nodeid: str, location: tuple[str, int | None, str],
) -> None:
"""Called at the end of running the runtest protocol for a single item.
@ -693,7 +695,7 @@ def pytest_runtest_logfinish(
"""
def pytest_runtest_setup(item: "Item") -> None:
def pytest_runtest_setup(item: Item) -> None:
"""Called to perform the setup phase for a test item.
The default implementation runs ``setup()`` on ``item`` and all of its
@ -712,7 +714,7 @@ def pytest_runtest_setup(item: "Item") -> None:
"""
def pytest_runtest_call(item: "Item") -> None:
def pytest_runtest_call(item: Item) -> None:
"""Called to run the test for test item (the call phase).
The default implementation calls ``item.runtest()``.
@ -728,7 +730,7 @@ def pytest_runtest_call(item: "Item") -> None:
"""
def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None:
def pytest_runtest_teardown(item: Item, nextitem: Item | None) -> None:
"""Called to perform the teardown phase for a test item.
The default implementation runs the finalizers and calls ``teardown()``
@ -754,8 +756,8 @@ def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None:
@hookspec(firstresult=True)
def pytest_runtest_makereport(
item: "Item", call: "CallInfo[None]"
) -> Optional["TestReport"]:
item: Item, call: CallInfo[None],
) -> TestReport | None:
"""Called to create a :class:`~pytest.TestReport` for each of
the setup, call and teardown runtest phases of a test item.
@ -774,7 +776,7 @@ def pytest_runtest_makereport(
"""
def pytest_runtest_logreport(report: "TestReport") -> None:
def pytest_runtest_logreport(report: TestReport) -> None:
"""Process the :class:`~pytest.TestReport` produced for each
of the setup, call and teardown runtest phases of an item.
@ -790,9 +792,9 @@ def pytest_runtest_logreport(report: "TestReport") -> None:
@hookspec(firstresult=True)
def pytest_report_to_serializable(
config: "Config",
report: Union["CollectReport", "TestReport"],
) -> Optional[Dict[str, Any]]:
config: Config,
report: CollectReport | TestReport,
) -> dict[str, Any] | None:
"""Serialize the given report object into a data structure suitable for
sending over the wire, e.g. converted to JSON.
@ -809,9 +811,9 @@ def pytest_report_to_serializable(
@hookspec(firstresult=True)
def pytest_report_from_serializable(
config: "Config",
data: Dict[str, Any],
) -> Optional[Union["CollectReport", "TestReport"]]:
config: Config,
data: dict[str, Any],
) -> CollectReport | TestReport | None:
"""Restore a report object previously serialized with
:hook:`pytest_report_to_serializable`.
@ -832,8 +834,8 @@ def pytest_report_from_serializable(
@hookspec(firstresult=True)
def pytest_fixture_setup(
fixturedef: "FixtureDef[Any]", request: "SubRequest"
) -> Optional[object]:
fixturedef: FixtureDef[Any], request: SubRequest,
) -> object | None:
"""Perform fixture setup execution.
:param fixturdef:
@ -860,7 +862,7 @@ def pytest_fixture_setup(
def pytest_fixture_post_finalizer(
fixturedef: "FixtureDef[Any]", request: "SubRequest"
fixturedef: FixtureDef[Any], request: SubRequest,
) -> None:
"""Called after fixture teardown, but before the cache is cleared, so
the fixture result ``fixturedef.cached_result`` is still available (not
@ -885,7 +887,7 @@ def pytest_fixture_post_finalizer(
# -------------------------------------------------------------------------
def pytest_sessionstart(session: "Session") -> None:
def pytest_sessionstart(session: Session) -> None:
"""Called after the ``Session`` object has been created and before performing collection
and entering the run test loop.
@ -899,8 +901,8 @@ def pytest_sessionstart(session: "Session") -> None:
def pytest_sessionfinish(
session: "Session",
exitstatus: Union[int, "ExitCode"],
session: Session,
exitstatus: int | ExitCode,
) -> None:
"""Called after whole test run finished, right before returning the exit status to the system.
@ -914,7 +916,7 @@ def pytest_sessionfinish(
"""
def pytest_unconfigure(config: "Config") -> None:
def pytest_unconfigure(config: Config) -> None:
"""Called before test process is exited.
:param config: The pytest config object.
@ -932,8 +934,8 @@ def pytest_unconfigure(config: "Config") -> None:
def pytest_assertrepr_compare(
config: "Config", op: str, left: object, right: object
) -> Optional[List[str]]:
config: Config, op: str, left: object, right: object,
) -> list[str] | None:
"""Return explanation for comparisons in failing assert expressions.
Return None for no custom explanation, otherwise return a list
@ -954,7 +956,7 @@ def pytest_assertrepr_compare(
"""
def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> None:
def pytest_assertion_pass(item: Item, lineno: int, orig: str, expl: str) -> None:
"""Called whenever an assertion passes.
.. versionadded:: 5.0
@ -994,8 +996,8 @@ def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> No
def pytest_report_header( # type:ignore[empty-body]
config: "Config", start_path: Path, startdir: "LEGACY_PATH"
) -> Union[str, List[str]]:
config: Config, start_path: Path, startdir: LEGACY_PATH,
) -> str | list[str]:
"""Return a string or list of strings to be displayed as header info for terminal reporting.
:param config: The pytest config object.
@ -1022,11 +1024,11 @@ def pytest_report_header( # type:ignore[empty-body]
def pytest_report_collectionfinish( # type:ignore[empty-body]
config: "Config",
config: Config,
start_path: Path,
startdir: "LEGACY_PATH",
items: Sequence["Item"],
) -> Union[str, List[str]]:
startdir: LEGACY_PATH,
items: Sequence[Item],
) -> str | list[str]:
"""Return a string or list of strings to be displayed after collection
has finished successfully.
@ -1060,8 +1062,8 @@ def pytest_report_collectionfinish( # type:ignore[empty-body]
@hookspec(firstresult=True)
def pytest_report_teststatus( # type:ignore[empty-body]
report: Union["CollectReport", "TestReport"], config: "Config"
) -> "TestShortLogReport | Tuple[str, str, Union[str, Tuple[str, Mapping[str, bool]]]]":
report: CollectReport | TestReport, config: Config,
) -> TestShortLogReport | Tuple[str, str, Union[str, Tuple[str, Mapping[str, bool]]]]:
"""Return result-category, shortletter and verbose word for status
reporting.
@ -1092,9 +1094,9 @@ def pytest_report_teststatus( # type:ignore[empty-body]
def pytest_terminal_summary(
terminalreporter: "TerminalReporter",
exitstatus: "ExitCode",
config: "Config",
terminalreporter: TerminalReporter,
exitstatus: ExitCode,
config: Config,
) -> None:
"""Add a section to terminal summary reporting.
@ -1114,10 +1116,10 @@ def pytest_terminal_summary(
@hookspec(historic=True)
def pytest_warning_recorded(
warning_message: "warnings.WarningMessage",
when: "Literal['config', 'collect', 'runtest']",
warning_message: warnings.WarningMessage,
when: Literal['config', 'collect', 'runtest'],
nodeid: str,
location: Optional[Tuple[str, int, str]],
location: tuple[str, int, str] | None,
) -> None:
"""Process a warning captured by the internal pytest warnings plugin.
@ -1158,8 +1160,8 @@ def pytest_warning_recorded(
def pytest_markeval_namespace( # type:ignore[empty-body]
config: "Config",
) -> Dict[str, Any]:
config: Config,
) -> dict[str, Any]:
"""Called when constructing the globals dictionary used for
evaluating string conditions in xfail/skipif markers.
@ -1187,9 +1189,9 @@ def pytest_markeval_namespace( # type:ignore[empty-body]
def pytest_internalerror(
excrepr: "ExceptionRepr",
excinfo: "ExceptionInfo[BaseException]",
) -> Optional[bool]:
excrepr: ExceptionRepr,
excinfo: ExceptionInfo[BaseException],
) -> bool | None:
"""Called for internal errors.
Return True to suppress the fallback handling of printing an
@ -1206,7 +1208,7 @@ def pytest_internalerror(
def pytest_keyboard_interrupt(
excinfo: "ExceptionInfo[Union[KeyboardInterrupt, Exit]]",
excinfo: ExceptionInfo[Union[KeyboardInterrupt, Exit]],
) -> None:
"""Called for keyboard interrupt.
@ -1220,9 +1222,9 @@ def pytest_keyboard_interrupt(
def pytest_exception_interact(
node: Union["Item", "Collector"],
call: "CallInfo[Any]",
report: Union["CollectReport", "TestReport"],
node: Item | Collector,
call: CallInfo[Any],
report: CollectReport | TestReport,
) -> None:
"""Called when an exception was raised which can potentially be
interactively handled.
@ -1251,7 +1253,7 @@ def pytest_exception_interact(
"""
def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
def pytest_enter_pdb(config: Config, pdb: pdb.Pdb) -> None:
"""Called upon pdb.set_trace().
Can be used by plugins to take special action just before the python
@ -1267,7 +1269,7 @@ def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
"""
def pytest_leave_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
def pytest_leave_pdb(config: Config, pdb: pdb.Pdb) -> None:
"""Called when leaving pdb (e.g. with continue after pdb.set_trace()).
Can be used by plugins to take special action just after the python

View file

@ -7,11 +7,14 @@ Based on initial code from Ross Lawley.
Output conforms to
https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd
"""
from datetime import datetime
from __future__ import annotations
import functools
import os
import platform
import re
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Callable
from typing import Dict
from typing import List
@ -19,8 +22,8 @@ from typing import Match
from typing import Optional
from typing import Tuple
from typing import Union
import xml.etree.ElementTree as ET
import pytest
from _pytest import nodes
from _pytest import timing
from _pytest._code.code import ExceptionRepr
@ -32,10 +35,9 @@ from _pytest.fixtures import FixtureRequest
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
import pytest
xml_key = StashKey["LogXML"]()
xml_key = StashKey['LogXML']()
def bin_xml_escape(arg: object) -> str:
@ -52,15 +54,15 @@ def bin_xml_escape(arg: object) -> str:
def repl(matchobj: Match[str]) -> str:
i = ord(matchobj.group())
if i <= 0xFF:
return "#x%02X" % i
return '#x%02X' % i
else:
return "#x%04X" % i
return '#x%04X' % i
# The spec range of valid chars is:
# Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
# For an unknown(?) reason, we disallow #x7F (DEL) as well.
illegal_xml_re = (
"[^\u0009\u000A\u000D\u0020-\u007E\u0080-\uD7FF\uE000-\uFFFD\u10000-\u10FFFF]"
'[^\u0009\u000A\u000D\u0020-\u007E\u0080-\uD7FF\uE000-\uFFFD\u10000-\u10FFFF]'
)
return re.sub(illegal_xml_re, repl, str(arg))
@ -76,27 +78,27 @@ def merge_family(left, right) -> None:
families = {}
families["_base"] = {"testcase": ["classname", "name"]}
families["_base_legacy"] = {"testcase": ["file", "line", "url"]}
families['_base'] = {'testcase': ['classname', 'name']}
families['_base_legacy'] = {'testcase': ['file', 'line', 'url']}
# xUnit 1.x inherits legacy attributes.
families["xunit1"] = families["_base"].copy()
merge_family(families["xunit1"], families["_base_legacy"])
families['xunit1'] = families['_base'].copy()
merge_family(families['xunit1'], families['_base_legacy'])
# xUnit 2.x uses strict base attributes.
families["xunit2"] = families["_base"]
families['xunit2'] = families['_base']
class _NodeReporter:
def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None:
def __init__(self, nodeid: str | TestReport, xml: LogXML) -> None:
self.id = nodeid
self.xml = xml
self.add_stats = self.xml.add_stats
self.family = self.xml.family
self.duration = 0.0
self.properties: List[Tuple[str, str]] = []
self.nodes: List[ET.Element] = []
self.attrs: Dict[str, str] = {}
self.properties: list[tuple[str, str]] = []
self.nodes: list[ET.Element] = []
self.attrs: dict[str, str] = {}
def append(self, node: ET.Element) -> None:
self.xml.add_stats(node.tag)
@ -108,12 +110,12 @@ class _NodeReporter:
def add_attribute(self, name: str, value: object) -> None:
self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self) -> Optional[ET.Element]:
def make_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any."""
if self.properties:
properties = ET.Element("properties")
properties = ET.Element('properties')
for name, value in self.properties:
properties.append(ET.Element("property", name=name, value=value))
properties.append(ET.Element('property', name=name, value=value))
return properties
return None
@ -123,39 +125,39 @@ class _NodeReporter:
classnames = names[:-1]
if self.xml.prefix:
classnames.insert(0, self.xml.prefix)
attrs: Dict[str, str] = {
"classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]),
"file": testreport.location[0],
attrs: dict[str, str] = {
'classname': '.'.join(classnames),
'name': bin_xml_escape(names[-1]),
'file': testreport.location[0],
}
if testreport.location[1] is not None:
attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"):
attrs["url"] = testreport.url
attrs['line'] = str(testreport.location[1])
if hasattr(testreport, 'url'):
attrs['url'] = testreport.url
self.attrs = attrs
self.attrs.update(existing_attrs) # Restore any user-defined attributes.
# Preserve legacy testcase behavior.
if self.family == "xunit1":
if self.family == 'xunit1':
return
# Filter out attributes not permitted by this test family.
# Including custom attributes because they are not valid here.
temp_attrs = {}
for key in self.attrs.keys():
if key in families[self.family]["testcase"]:
if key in families[self.family]['testcase']:
temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs
def to_xml(self) -> ET.Element:
testcase = ET.Element("testcase", self.attrs, time="%.3f" % self.duration)
testcase = ET.Element('testcase', self.attrs, time='%.3f' % self.duration)
properties = self.make_properties_node()
if properties is not None:
testcase.append(properties)
testcase.extend(self.nodes)
return testcase
def _add_simple(self, tag: str, message: str, data: Optional[str] = None) -> None:
def _add_simple(self, tag: str, message: str, data: str | None = None) -> None:
node = ET.Element(tag, message=message)
node.text = bin_xml_escape(data)
self.append(node)
@ -167,24 +169,24 @@ class _NodeReporter:
content_out = report.capstdout
content_log = report.caplog
content_err = report.capstderr
if self.xml.logging == "no":
if self.xml.logging == 'no':
return
content_all = ""
if self.xml.logging in ["log", "all"]:
content_all = self._prepare_content(content_log, " Captured Log ")
if self.xml.logging in ["system-out", "out-err", "all"]:
content_all += self._prepare_content(content_out, " Captured Out ")
self._write_content(report, content_all, "system-out")
content_all = ""
if self.xml.logging in ["system-err", "out-err", "all"]:
content_all += self._prepare_content(content_err, " Captured Err ")
self._write_content(report, content_all, "system-err")
content_all = ""
content_all = ''
if self.xml.logging in ['log', 'all']:
content_all = self._prepare_content(content_log, ' Captured Log ')
if self.xml.logging in ['system-out', 'out-err', 'all']:
content_all += self._prepare_content(content_out, ' Captured Out ')
self._write_content(report, content_all, 'system-out')
content_all = ''
if self.xml.logging in ['system-err', 'out-err', 'all']:
content_all += self._prepare_content(content_err, ' Captured Err ')
self._write_content(report, content_all, 'system-err')
content_all = ''
if content_all:
self._write_content(report, content_all, "system-out")
self._write_content(report, content_all, 'system-out')
def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""])
return '\n'.join([header.center(80, '-'), content, ''])
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = ET.Element(jheader)
@ -192,65 +194,65 @@ class _NodeReporter:
self.append(tag)
def append_pass(self, report: TestReport) -> None:
self.add_stats("passed")
self.add_stats('passed')
def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"):
self._add_simple("skipped", "xfail-marked test passes unexpectedly")
if hasattr(report, 'wasxfail'):
self._add_simple('skipped', 'xfail-marked test passes unexpectedly')
else:
assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr(
report.longrepr, "reprcrash", None
reprcrash: ReprFileLocation | None = getattr(
report.longrepr, 'reprcrash', None,
)
if reprcrash is not None:
message = reprcrash.message
else:
message = str(report.longrepr)
message = bin_xml_escape(message)
self._add_simple("failure", message, str(report.longrepr))
self._add_simple('failure', message, str(report.longrepr))
def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None
self._add_simple("error", "collection failure", str(report.longrepr))
self._add_simple('error', 'collection failure', str(report.longrepr))
def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple("skipped", "collection skipped", str(report.longrepr))
self._add_simple('skipped', 'collection skipped', str(report.longrepr))
def append_error(self, report: TestReport) -> None:
assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr(
report.longrepr, "reprcrash", None
reprcrash: ReprFileLocation | None = getattr(
report.longrepr, 'reprcrash', None,
)
if reprcrash is not None:
reason = reprcrash.message
else:
reason = str(report.longrepr)
if report.when == "teardown":
if report.when == 'teardown':
msg = f'failed on teardown with "{reason}"'
else:
msg = f'failed on setup with "{reason}"'
self._add_simple("error", bin_xml_escape(msg), str(report.longrepr))
self._add_simple('error', bin_xml_escape(msg), str(report.longrepr))
def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
if hasattr(report, 'wasxfail'):
xfailreason = report.wasxfail
if xfailreason.startswith("reason: "):
if xfailreason.startswith('reason: '):
xfailreason = xfailreason[8:]
xfailreason = bin_xml_escape(xfailreason)
skipped = ET.Element("skipped", type="pytest.xfail", message=xfailreason)
skipped = ET.Element('skipped', type='pytest.xfail', message=xfailreason)
self.append(skipped)
else:
assert isinstance(report.longrepr, tuple)
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
if skipreason.startswith('Skipped: '):
skipreason = skipreason[9:]
details = f"{filename}:{lineno}: {skipreason}"
details = f'{filename}:{lineno}: {skipreason}'
skipped = ET.Element(
"skipped", type="pytest.skip", message=bin_xml_escape(skipreason)
'skipped', type='pytest.skip', message=bin_xml_escape(skipreason),
)
skipped.text = bin_xml_escape(details)
self.append(skipped)
@ -265,17 +267,17 @@ class _NodeReporter:
def _warn_incompatibility_with_xunit2(
request: FixtureRequest, fixture_name: str
request: FixtureRequest, fixture_name: str,
) -> None:
"""Emit a PytestWarning about the given fixture being incompatible with newer xunit revisions."""
from _pytest.warning_types import PytestWarning
xml = request.config.stash.get(xml_key, None)
if xml is not None and xml.family not in ("xunit1", "legacy"):
if xml is not None and xml.family not in ('xunit1', 'legacy'):
request.node.warn(
PytestWarning(
f"{fixture_name} is incompatible with junit_family '{xml.family}' (use 'legacy' or 'xunit1')"
)
f"{fixture_name} is incompatible with junit_family '{xml.family}' (use 'legacy' or 'xunit1')",
),
)
@ -294,7 +296,7 @@ def record_property(request: FixtureRequest) -> Callable[[str, object], None]:
def test_function(record_property):
record_property("example_key", 1)
"""
_warn_incompatibility_with_xunit2(request, "record_property")
_warn_incompatibility_with_xunit2(request, 'record_property')
def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value))
@ -312,10 +314,10 @@ def record_xml_attribute(request: FixtureRequest) -> Callable[[str, object], Non
from _pytest.warning_types import PytestExperimentalApiWarning
request.node.warn(
PytestExperimentalApiWarning("record_xml_attribute is an experimental feature")
PytestExperimentalApiWarning('record_xml_attribute is an experimental feature'),
)
_warn_incompatibility_with_xunit2(request, "record_xml_attribute")
_warn_incompatibility_with_xunit2(request, 'record_xml_attribute')
# Declare noop
def add_attr_noop(name: str, value: object) -> None:
@ -336,11 +338,11 @@ def _check_record_param_type(param: str, v: str) -> None:
type."""
__tracebackhide__ = True
if not isinstance(v, str):
msg = "{param} parameter needs to be a string, but {g} given" # type: ignore[unreachable]
msg = '{param} parameter needs to be a string, but {g} given' # type: ignore[unreachable]
raise TypeError(msg.format(param=param, g=type(v).__name__))
@pytest.fixture(scope="session")
@pytest.fixture(scope='session')
def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object], None]:
"""Record a new ``<property>`` tag as child of the root ``<testsuite>``.
@ -371,7 +373,7 @@ def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object]
def record_func(name: str, value: object) -> None:
"""No-op function in case --junit-xml was not passed in the command-line."""
__tracebackhide__ = True
_check_record_param_type("name", name)
_check_record_param_type('name', name)
xml = request.config.stash.get(xml_key, None)
if xml is not None:
@ -380,65 +382,65 @@ def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object]
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group = parser.getgroup('terminal reporting')
group.addoption(
"--junitxml",
"--junit-xml",
action="store",
dest="xmlpath",
metavar="path",
type=functools.partial(filename_arg, optname="--junitxml"),
'--junitxml',
'--junit-xml',
action='store',
dest='xmlpath',
metavar='path',
type=functools.partial(filename_arg, optname='--junitxml'),
default=None,
help="Create junit-xml style report file at given path",
help='Create junit-xml style report file at given path',
)
group.addoption(
"--junitprefix",
"--junit-prefix",
action="store",
metavar="str",
'--junitprefix',
'--junit-prefix',
action='store',
metavar='str',
default=None,
help="Prepend prefix to classnames in junit-xml output",
help='Prepend prefix to classnames in junit-xml output',
)
parser.addini(
"junit_suite_name", "Test suite name for JUnit report", default="pytest"
'junit_suite_name', 'Test suite name for JUnit report', default='pytest',
)
parser.addini(
"junit_logging",
"Write captured log messages to JUnit report: "
"one of no|log|system-out|system-err|out-err|all",
default="no",
'junit_logging',
'Write captured log messages to JUnit report: '
'one of no|log|system-out|system-err|out-err|all',
default='no',
)
parser.addini(
"junit_log_passing_tests",
"Capture log information for passing tests to JUnit report: ",
type="bool",
'junit_log_passing_tests',
'Capture log information for passing tests to JUnit report: ',
type='bool',
default=True,
)
parser.addini(
"junit_duration_report",
"Duration time to report: one of total|call",
default="total",
'junit_duration_report',
'Duration time to report: one of total|call',
default='total',
) # choices=['total', 'call'])
parser.addini(
"junit_family",
"Emit XML for schema: one of legacy|xunit1|xunit2",
default="xunit2",
'junit_family',
'Emit XML for schema: one of legacy|xunit1|xunit2',
default='xunit2',
)
def pytest_configure(config: Config) -> None:
xmlpath = config.option.xmlpath
# Prevent opening xmllog on worker nodes (xdist).
if xmlpath and not hasattr(config, "workerinput"):
junit_family = config.getini("junit_family")
if xmlpath and not hasattr(config, 'workerinput'):
junit_family = config.getini('junit_family')
config.stash[xml_key] = LogXML(
xmlpath,
config.option.junitprefix,
config.getini("junit_suite_name"),
config.getini("junit_logging"),
config.getini("junit_duration_report"),
config.getini('junit_suite_name'),
config.getini('junit_logging'),
config.getini('junit_duration_report'),
junit_family,
config.getini("junit_log_passing_tests"),
config.getini('junit_log_passing_tests'),
)
config.pluginmanager.register(config.stash[xml_key])
@ -450,12 +452,12 @@ def pytest_unconfigure(config: Config) -> None:
config.pluginmanager.unregister(xml)
def mangle_test_address(address: str) -> List[str]:
path, possible_open_bracket, params = address.partition("[")
names = path.split("::")
def mangle_test_address(address: str) -> list[str]:
path, possible_open_bracket, params = address.partition('[')
names = path.split('::')
# Convert file path to dotted path.
names[0] = names[0].replace(nodes.SEP, ".")
names[0] = re.sub(r"\.py$", "", names[0])
names[0] = names[0].replace(nodes.SEP, '.')
names[0] = re.sub(r'\.py$', '', names[0])
# Put any params back.
names[-1] += possible_open_bracket + params
return names
@ -465,11 +467,11 @@ class LogXML:
def __init__(
self,
logfile,
prefix: Optional[str],
suite_name: str = "pytest",
logging: str = "no",
report_duration: str = "total",
family="xunit1",
prefix: str | None,
suite_name: str = 'pytest',
logging: str = 'no',
report_duration: str = 'total',
family='xunit1',
log_passing_tests: bool = True,
) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile))
@ -480,27 +482,27 @@ class LogXML:
self.log_passing_tests = log_passing_tests
self.report_duration = report_duration
self.family = family
self.stats: Dict[str, int] = dict.fromkeys(
["error", "passed", "failure", "skipped"], 0
self.stats: dict[str, int] = dict.fromkeys(
['error', 'passed', 'failure', 'skipped'], 0,
)
self.node_reporters: Dict[
Tuple[Union[str, TestReport], object], _NodeReporter
self.node_reporters: dict[
tuple[str | TestReport, object], _NodeReporter,
] = {}
self.node_reporters_ordered: List[_NodeReporter] = []
self.global_properties: List[Tuple[str, str]] = []
self.node_reporters_ordered: list[_NodeReporter] = []
self.global_properties: list[tuple[str, str]] = []
# List of reports that failed on call but teardown is pending.
self.open_reports: List[TestReport] = []
self.open_reports: list[TestReport] = []
self.cnt_double_fail_tests = 0
# Replaces convenience family with real family.
if self.family == "legacy":
self.family = "xunit1"
if self.family == 'legacy':
self.family = 'xunit1'
def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report)
nodeid = getattr(report, 'nodeid', report)
# Local hack to handle xdist report order.
workernode = getattr(report, "node", None)
workernode = getattr(report, 'node', None)
reporter = self.node_reporters.pop((nodeid, workernode))
for propname, propvalue in report.user_properties:
@ -509,10 +511,10 @@ class LogXML:
if reporter is not None:
reporter.finalize()
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter:
nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
def node_reporter(self, report: TestReport | str) -> _NodeReporter:
nodeid: str | TestReport = getattr(report, 'nodeid', report)
# Local hack to handle xdist report order.
workernode = getattr(report, "node", None)
workernode = getattr(report, 'node', None)
key = nodeid, workernode
@ -561,22 +563,22 @@ class LogXML:
"""
close_report = None
if report.passed:
if report.when == "call": # ignore setup/teardown
if report.when == 'call': # ignore setup/teardown
reporter = self._opentestcase(report)
reporter.append_pass(report)
elif report.failed:
if report.when == "teardown":
if report.when == 'teardown':
# The following vars are needed when xdist plugin is used.
report_wid = getattr(report, "worker_id", None)
report_ii = getattr(report, "item_index", None)
report_wid = getattr(report, 'worker_id', None)
report_ii = getattr(report, 'item_index', None)
close_report = next(
(
rep
for rep in self.open_reports
if (
rep.nodeid == report.nodeid
and getattr(rep, "item_index", None) == report_ii
and getattr(rep, "worker_id", None) == report_wid
rep.nodeid == report.nodeid and
getattr(rep, 'item_index', None) == report_ii and
getattr(rep, 'worker_id', None) == report_wid
)
),
None,
@ -588,7 +590,7 @@ class LogXML:
self.finalize(close_report)
self.cnt_double_fail_tests += 1
reporter = self._opentestcase(report)
if report.when == "call":
if report.when == 'call':
reporter.append_failure(report)
self.open_reports.append(report)
if not self.log_passing_tests:
@ -599,21 +601,21 @@ class LogXML:
reporter = self._opentestcase(report)
reporter.append_skipped(report)
self.update_testcase_duration(report)
if report.when == "teardown":
if report.when == 'teardown':
reporter = self._opentestcase(report)
reporter.write_captured_output(report)
self.finalize(report)
report_wid = getattr(report, "worker_id", None)
report_ii = getattr(report, "item_index", None)
report_wid = getattr(report, 'worker_id', None)
report_ii = getattr(report, 'item_index', None)
close_report = next(
(
rep
for rep in self.open_reports
if (
rep.nodeid == report.nodeid
and getattr(rep, "item_index", None) == report_ii
and getattr(rep, "worker_id", None) == report_wid
rep.nodeid == report.nodeid and
getattr(rep, 'item_index', None) == report_ii and
getattr(rep, 'worker_id', None) == report_wid
)
),
None,
@ -624,9 +626,9 @@ class LogXML:
def update_testcase_duration(self, report: TestReport) -> None:
"""Accumulate total duration for nodeid from given report and update
the Junit.testcase with the new total if already created."""
if self.report_duration in {"total", report.when}:
if self.report_duration in {'total', report.when}:
reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0)
reporter.duration += getattr(report, 'duration', 0.0)
def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed:
@ -637,9 +639,9 @@ class LogXML:
reporter.append_collect_skipped(report)
def pytest_internalerror(self, excrepr: ExceptionRepr) -> None:
reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal")
reporter._add_simple("error", "internal error", str(excrepr))
reporter = self.node_reporter('internal')
reporter.attrs.update(classname='pytest', name='internal')
reporter._add_simple('error', 'internal error', str(excrepr))
def pytest_sessionstart(self) -> None:
self.suite_start_time = timing.time()
@ -649,27 +651,27 @@ class LogXML:
# exist_ok avoids filesystem race conditions between checking path existence and requesting creation
os.makedirs(dirname, exist_ok=True)
with open(self.logfile, "w", encoding="utf-8") as logfile:
with open(self.logfile, 'w', encoding='utf-8') as logfile:
suite_stop_time = timing.time()
suite_time_delta = suite_stop_time - self.suite_start_time
numtests = (
self.stats["passed"]
+ self.stats["failure"]
+ self.stats["skipped"]
+ self.stats["error"]
- self.cnt_double_fail_tests
self.stats['passed'] +
self.stats['failure'] +
self.stats['skipped'] +
self.stats['error'] -
self.cnt_double_fail_tests
)
logfile.write('<?xml version="1.0" encoding="utf-8"?>')
suite_node = ET.Element(
"testsuite",
'testsuite',
name=self.suite_name,
errors=str(self.stats["error"]),
failures=str(self.stats["failure"]),
skipped=str(self.stats["skipped"]),
errors=str(self.stats['error']),
failures=str(self.stats['failure']),
skipped=str(self.stats['skipped']),
tests=str(numtests),
time="%.3f" % suite_time_delta,
time='%.3f' % suite_time_delta,
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(),
)
@ -678,23 +680,23 @@ class LogXML:
suite_node.append(global_properties)
for node_reporter in self.node_reporters_ordered:
suite_node.append(node_reporter.to_xml())
testsuites = ET.Element("testsuites")
testsuites = ET.Element('testsuites')
testsuites.append(suite_node)
logfile.write(ET.tostring(testsuites, encoding="unicode"))
logfile.write(ET.tostring(testsuites, encoding='unicode'))
def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
terminalreporter.write_sep("-", f"generated xml file: {self.logfile}")
terminalreporter.write_sep('-', f'generated xml file: {self.logfile}')
def add_global_property(self, name: str, value: object) -> None:
__tracebackhide__ = True
_check_record_param_type("name", name)
_check_record_param_type('name', name)
self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self) -> Optional[ET.Element]:
def _get_global_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any."""
if self.global_properties:
properties = ET.Element("properties")
properties = ET.Element('properties')
for name, value in self.global_properties:
properties.append(ET.Element("property", name=name, value=value))
properties.append(ET.Element('property', name=name, value=value))
return properties
return None

View file

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs
"""Add backward compatibility support for the legacy py path type."""
from __future__ import annotations
import dataclasses
from pathlib import Path
import shlex
import subprocess
from pathlib import Path
from typing import Final
from typing import final
from typing import List
@ -12,8 +13,6 @@ from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from iniconfig import SectionWrapper
from _pytest.cacheprovider import Cache
from _pytest.compat import LEGACY_PATH
from _pytest.compat import legacy_path
@ -33,6 +32,7 @@ from _pytest.pytester import Pytester
from _pytest.pytester import RunResult
from _pytest.terminal import TerminalReporter
from _pytest.tmpdir import TempPathFactory
from iniconfig import SectionWrapper
if TYPE_CHECKING:
@ -50,8 +50,8 @@ class Testdir:
__test__ = False
CLOSE_STDIN: "Final" = Pytester.CLOSE_STDIN
TimeoutExpired: "Final" = Pytester.TimeoutExpired
CLOSE_STDIN: Final = Pytester.CLOSE_STDIN
TimeoutExpired: Final = Pytester.TimeoutExpired
def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
@ -95,14 +95,14 @@ class Testdir:
def makefile(self, ext, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.makefile`."""
if ext and not ext.startswith("."):
if ext and not ext.startswith('.'):
# pytester.makefile is going to throw a ValueError in a way that
# testdir.makefile did not, because
# pathlib.Path is stricter suffixes than py.path
# This ext arguments is likely user error, but since testdir has
# allowed this, we will prepend "." as a workaround to avoid breaking
# testdir usage that worked before
ext = "." + ext
ext = '.' + ext
return legacy_path(self._pytester.makefile(ext, *args, **kwargs))
def makeconftest(self, source) -> LEGACY_PATH:
@ -145,7 +145,7 @@ class Testdir:
"""See :meth:`Pytester.copy_example`."""
return legacy_path(self._pytester.copy_example(name))
def getnode(self, config: Config, arg) -> Optional[Union[Item, Collector]]:
def getnode(self, config: Config, arg) -> Item | Collector | None:
"""See :meth:`Pytester.getnode`."""
return self._pytester.getnode(config, arg)
@ -153,7 +153,7 @@ class Testdir:
"""See :meth:`Pytester.getpathnode`."""
return self._pytester.getpathnode(path)
def genitems(self, colitems: List[Union[Item, Collector]]) -> List[Item]:
def genitems(self, colitems: list[Item | Collector]) -> list[Item]:
"""See :meth:`Pytester.genitems`."""
return self._pytester.genitems(colitems)
@ -172,7 +172,7 @@ class Testdir:
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
"""See :meth:`Pytester.inline_run`."""
return self._pytester.inline_run(
*args, plugins=plugins, no_reraise_ctrlc=no_reraise_ctrlc
*args, plugins=plugins, no_reraise_ctrlc=no_reraise_ctrlc,
)
def runpytest_inprocess(self, *args, **kwargs) -> RunResult:
@ -191,7 +191,7 @@ class Testdir:
"""See :meth:`Pytester.parseconfigure`."""
return self._pytester.parseconfigure(*args)
def getitem(self, source, funcname="test_func"):
def getitem(self, source, funcname='test_func'):
"""See :meth:`Pytester.getitem`."""
return self._pytester.getitem(source, funcname)
@ -202,12 +202,12 @@ class Testdir:
def getmodulecol(self, source, configargs=(), withinit=False):
"""See :meth:`Pytester.getmodulecol`."""
return self._pytester.getmodulecol(
source, configargs=configargs, withinit=withinit
source, configargs=configargs, withinit=withinit,
)
def collect_by_name(
self, modcol: Collector, name: str
) -> Optional[Union[Item, Collector]]:
self, modcol: Collector, name: str,
) -> Item | Collector | None:
"""See :meth:`Pytester.collect_by_name`."""
return self._pytester.collect_by_name(modcol, name)
@ -239,17 +239,17 @@ class Testdir:
return self._pytester.runpytest_subprocess(*args, timeout=timeout)
def spawn_pytest(
self, string: str, expect_timeout: float = 10.0
) -> "pexpect.spawn":
self, string: str, expect_timeout: float = 10.0,
) -> pexpect.spawn:
"""See :meth:`Pytester.spawn_pytest`."""
return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout)
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> "pexpect.spawn":
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> pexpect.spawn:
"""See :meth:`Pytester.spawn`."""
return self._pytester.spawn(cmd, expect_timeout=expect_timeout)
def __repr__(self) -> str:
return f"<Testdir {self.tmpdir!r}>"
return f'<Testdir {self.tmpdir!r}>'
def __str__(self) -> str:
return str(self.tmpdir)
@ -284,7 +284,7 @@ class TempdirFactory:
_tmppath_factory: TempPathFactory
def __init__(
self, tmppath_factory: TempPathFactory, *, _ispytest: bool = False
self, tmppath_factory: TempPathFactory, *, _ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._tmppath_factory = tmppath_factory
@ -300,7 +300,7 @@ class TempdirFactory:
class LegacyTmpdirPlugin:
@staticmethod
@fixture(scope="session")
@fixture(scope='session')
def tmpdir_factory(request: FixtureRequest) -> TempdirFactory:
"""Return a :class:`pytest.TempdirFactory` instance for the test session."""
# Set dynamically by pytest_configure().
@ -374,7 +374,7 @@ def Config_rootdir(self: Config) -> LEGACY_PATH:
return legacy_path(str(self.rootpath))
def Config_inifile(self: Config) -> Optional[LEGACY_PATH]:
def Config_inifile(self: Config) -> LEGACY_PATH | None:
"""The path to the :ref:`configfile <configfiles>`.
Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`.
@ -395,16 +395,16 @@ def Session_stardir(self: Session) -> LEGACY_PATH:
def Config__getini_unknown_type(
self, name: str, type: str, value: Union[str, List[str]]
self, name: str, type: str, value: str | list[str],
):
if type == "pathlist":
if type == 'pathlist':
# TODO: This assert is probably not valid in all cases.
assert self.inipath is not None
dp = self.inipath.parent
input_values = shlex.split(value) if isinstance(value, str) else value
return [legacy_path(str(dp / x)) for x in input_values]
else:
raise ValueError(f"unknown configuration type: {type}", value)
raise ValueError(f'unknown configuration type: {type}', value)
def Node_fspath(self: Node) -> LEGACY_PATH:
@ -423,35 +423,35 @@ def pytest_load_initial_conftests(early_config: Config) -> None:
early_config.add_cleanup(mp.undo)
# Add Cache.makedir().
mp.setattr(Cache, "makedir", Cache_makedir, raising=False)
mp.setattr(Cache, 'makedir', Cache_makedir, raising=False)
# Add FixtureRequest.fspath property.
mp.setattr(FixtureRequest, "fspath", property(FixtureRequest_fspath), raising=False)
mp.setattr(FixtureRequest, 'fspath', property(FixtureRequest_fspath), raising=False)
# Add TerminalReporter.startdir property.
mp.setattr(
TerminalReporter, "startdir", property(TerminalReporter_startdir), raising=False
TerminalReporter, 'startdir', property(TerminalReporter_startdir), raising=False,
)
# Add Config.{invocation_dir,rootdir,inifile} properties.
mp.setattr(Config, "invocation_dir", property(Config_invocation_dir), raising=False)
mp.setattr(Config, "rootdir", property(Config_rootdir), raising=False)
mp.setattr(Config, "inifile", property(Config_inifile), raising=False)
mp.setattr(Config, 'invocation_dir', property(Config_invocation_dir), raising=False)
mp.setattr(Config, 'rootdir', property(Config_rootdir), raising=False)
mp.setattr(Config, 'inifile', property(Config_inifile), raising=False)
# Add Session.startdir property.
mp.setattr(Session, "startdir", property(Session_stardir), raising=False)
mp.setattr(Session, 'startdir', property(Session_stardir), raising=False)
# Add pathlist configuration type.
mp.setattr(Config, "_getini_unknown_type", Config__getini_unknown_type)
mp.setattr(Config, '_getini_unknown_type', Config__getini_unknown_type)
# Add Node.fspath property.
mp.setattr(Node, "fspath", property(Node_fspath, Node_fspath_set), raising=False)
mp.setattr(Node, 'fspath', property(Node_fspath, Node_fspath_set), raising=False)
@hookimpl
def pytest_configure(config: Config) -> None:
"""Installs the LegacyTmpdirPlugin if the ``tmpdir`` plugin is also installed."""
if config.pluginmanager.has_plugin("tmpdir"):
if config.pluginmanager.has_plugin('tmpdir'):
mp = MonkeyPatch()
config.add_cleanup(mp.undo)
# Create TmpdirFactory and attach it to the config object.
@ -466,15 +466,15 @@ def pytest_configure(config: Config) -> None:
pass
else:
_tmpdirhandler = TempdirFactory(tmp_path_factory, _ispytest=True)
mp.setattr(config, "_tmpdirhandler", _tmpdirhandler, raising=False)
mp.setattr(config, '_tmpdirhandler', _tmpdirhandler, raising=False)
config.pluginmanager.register(LegacyTmpdirPlugin, "legacypath-tmpdir")
config.pluginmanager.register(LegacyTmpdirPlugin, 'legacypath-tmpdir')
@hookimpl
def pytest_plugin_registered(plugin: object, manager: PytestPluginManager) -> None:
# pytester is not loaded by default and is commonly loaded from a conftest,
# so checking for it in `pytest_configure` is not enough.
is_pytester = plugin is manager.get_plugin("pytester")
is_pytester = plugin is manager.get_plugin('pytester')
if is_pytester and not manager.is_registered(LegacyTestdirPlugin):
manager.register(LegacyTestdirPlugin, "legacypath-pytester")
manager.register(LegacyTestdirPlugin, 'legacypath-pytester')

View file

@ -1,17 +1,19 @@
# mypy: allow-untyped-defs
"""Access and control log capturing."""
from __future__ import annotations
import io
import logging
import os
import re
from contextlib import contextmanager
from contextlib import nullcontext
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import io
from io import StringIO
import logging
from logging import LogRecord
import os
from pathlib import Path
import re
from types import TracebackType
from typing import AbstractSet
from typing import Dict
@ -50,15 +52,15 @@ if TYPE_CHECKING:
else:
logging_StreamHandler = logging.StreamHandler
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
DEFAULT_LOG_DATE_FORMAT = "%H:%M:%S"
_ANSI_ESCAPE_SEQ = re.compile(r"\x1b\[[\d;]+m")
caplog_handler_key = StashKey["LogCaptureHandler"]()
DEFAULT_LOG_FORMAT = '%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s'
DEFAULT_LOG_DATE_FORMAT = '%H:%M:%S'
_ANSI_ESCAPE_SEQ = re.compile(r'\x1b\[[\d;]+m')
caplog_handler_key = StashKey['LogCaptureHandler']()
caplog_records_key = StashKey[Dict[str, List[logging.LogRecord]]]()
def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text)
return _ANSI_ESCAPE_SEQ.sub('', text)
class DatetimeFormatter(logging.Formatter):
@ -67,8 +69,8 @@ class DatetimeFormatter(logging.Formatter):
:func:`time.strftime` in case of microseconds in format string.
"""
def formatTime(self, record: LogRecord, datefmt: Optional[str] = None) -> str:
if datefmt and "%f" in datefmt:
def formatTime(self, record: LogRecord, datefmt: str | None = None) -> str:
if datefmt and '%f' in datefmt:
ct = self.converter(record.created)
tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone)
# Construct `datetime.datetime` object from `struct_time`
@ -85,21 +87,21 @@ class ColoredLevelFormatter(DatetimeFormatter):
log format passed to __init__."""
LOGLEVEL_COLOROPTS: Mapping[int, AbstractSet[str]] = {
logging.CRITICAL: {"red"},
logging.ERROR: {"red", "bold"},
logging.WARNING: {"yellow"},
logging.WARN: {"yellow"},
logging.INFO: {"green"},
logging.DEBUG: {"purple"},
logging.CRITICAL: {'red'},
logging.ERROR: {'red', 'bold'},
logging.WARNING: {'yellow'},
logging.WARN: {'yellow'},
logging.INFO: {'green'},
logging.DEBUG: {'purple'},
logging.NOTSET: set(),
}
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*(?:\.\d+)?s)")
LEVELNAME_FMT_REGEX = re.compile(r'%\(levelname\)([+-.]?\d*(?:\.\d+)?s)')
def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._terminalwriter = terminalwriter
self._original_fmt = self._style._fmt
self._level_to_fmt_mapping: Dict[int, str] = {}
self._level_to_fmt_mapping: dict[int, str] = {}
for level, color_opts in self.LOGLEVEL_COLOROPTS.items():
self.add_color_level(level, *color_opts)
@ -123,15 +125,15 @@ class ColoredLevelFormatter(DatetimeFormatter):
return
levelname_fmt = levelname_fmt_match.group()
formatted_levelname = levelname_fmt % {"levelname": logging.getLevelName(level)}
formatted_levelname = levelname_fmt % {'levelname': logging.getLevelName(level)}
# add ANSI escape sequences around the formatted levelname
color_kwargs = {name: True for name in color_opts}
colorized_formatted_levelname = self._terminalwriter.markup(
formatted_levelname, **color_kwargs
formatted_levelname, **color_kwargs,
)
self._level_to_fmt_mapping[level] = self.LEVELNAME_FMT_REGEX.sub(
colorized_formatted_levelname, self._fmt
colorized_formatted_levelname, self._fmt,
)
def format(self, record: logging.LogRecord) -> str:
@ -147,12 +149,12 @@ class PercentStyleMultiline(logging.PercentStyle):
formats the message as if each line were logged separately.
"""
def __init__(self, fmt: str, auto_indent: Union[int, str, bool, None]) -> None:
def __init__(self, fmt: str, auto_indent: int | str | bool | None) -> None:
super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod
def _get_auto_indent(auto_indent_option: Union[int, str, bool, None]) -> int:
def _get_auto_indent(auto_indent_option: int | str | bool | None) -> int:
"""Determine the current auto indentation setting.
Specify auto indent behavior (on/off/fixed) by passing in
@ -206,8 +208,8 @@ class PercentStyleMultiline(logging.PercentStyle):
return 0
def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message:
if hasattr(record, "auto_indent"):
if '\n' in record.message:
if hasattr(record, 'auto_indent'):
# Passed in from the "extra={}" kwarg on the call to logging.log().
auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined]
else:
@ -215,17 +217,17 @@ class PercentStyleMultiline(logging.PercentStyle):
if auto_indent:
lines = record.message.splitlines()
formatted = self._fmt % {**record.__dict__, "message": lines[0]}
formatted = self._fmt % {**record.__dict__, 'message': lines[0]}
if auto_indent < 0:
indentation = _remove_ansi_escape_sequences(formatted).find(
lines[0]
lines[0],
)
else:
# Optimizes logging by allowing a fixed indentation.
indentation = auto_indent
lines[0] = formatted
return ("\n" + " " * indentation).join(lines)
return ('\n' + ' ' * indentation).join(lines)
return self._fmt % record.__dict__
@ -240,114 +242,114 @@ def get_option_ini(config: Config, *names: str):
def pytest_addoption(parser: Parser) -> None:
"""Add options to control log capturing."""
group = parser.getgroup("logging")
group = parser.getgroup('logging')
def add_option_ini(option, dest, default=None, type=None, **kwargs):
parser.addini(
dest, default=default, type=type, help="Default value for " + option
dest, default=default, type=type, help='Default value for ' + option,
)
group.addoption(option, dest=dest, **kwargs)
add_option_ini(
"--log-level",
dest="log_level",
'--log-level',
dest='log_level',
default=None,
metavar="LEVEL",
metavar='LEVEL',
help=(
"Level of messages to catch/display."
'Level of messages to catch/display.'
" Not set by default, so it depends on the root/parent log handler's"
' effective level, where it is "WARNING" by default.'
),
)
add_option_ini(
"--log-format",
dest="log_format",
'--log-format',
dest='log_format',
default=DEFAULT_LOG_FORMAT,
help="Log format used by the logging module",
help='Log format used by the logging module',
)
add_option_ini(
"--log-date-format",
dest="log_date_format",
'--log-date-format',
dest='log_date_format',
default=DEFAULT_LOG_DATE_FORMAT,
help="Log date format used by the logging module",
help='Log date format used by the logging module',
)
parser.addini(
"log_cli",
'log_cli',
default=False,
type="bool",
type='bool',
help='Enable log display during test run (also known as "live logging")',
)
add_option_ini(
"--log-cli-level", dest="log_cli_level", default=None, help="CLI logging level"
'--log-cli-level', dest='log_cli_level', default=None, help='CLI logging level',
)
add_option_ini(
"--log-cli-format",
dest="log_cli_format",
'--log-cli-format',
dest='log_cli_format',
default=None,
help="Log format used by the logging module",
help='Log format used by the logging module',
)
add_option_ini(
"--log-cli-date-format",
dest="log_cli_date_format",
'--log-cli-date-format',
dest='log_cli_date_format',
default=None,
help="Log date format used by the logging module",
help='Log date format used by the logging module',
)
add_option_ini(
"--log-file",
dest="log_file",
'--log-file',
dest='log_file',
default=None,
help="Path to a file when logging will be written to",
help='Path to a file when logging will be written to',
)
add_option_ini(
"--log-file-mode",
dest="log_file_mode",
default="w",
choices=["w", "a"],
help="Log file open mode",
'--log-file-mode',
dest='log_file_mode',
default='w',
choices=['w', 'a'],
help='Log file open mode',
)
add_option_ini(
"--log-file-level",
dest="log_file_level",
'--log-file-level',
dest='log_file_level',
default=None,
help="Log file logging level",
help='Log file logging level',
)
add_option_ini(
"--log-file-format",
dest="log_file_format",
'--log-file-format',
dest='log_file_format',
default=None,
help="Log format used by the logging module",
help='Log format used by the logging module',
)
add_option_ini(
"--log-file-date-format",
dest="log_file_date_format",
'--log-file-date-format',
dest='log_file_date_format',
default=None,
help="Log date format used by the logging module",
help='Log date format used by the logging module',
)
add_option_ini(
"--log-auto-indent",
dest="log_auto_indent",
'--log-auto-indent',
dest='log_auto_indent',
default=None,
help="Auto-indent multiline messages passed to the logging module. Accepts true|on, false|off or an integer.",
help='Auto-indent multiline messages passed to the logging module. Accepts true|on, false|off or an integer.',
)
group.addoption(
"--log-disable",
action="append",
'--log-disable',
action='append',
default=[],
dest="logger_disable",
help="Disable a logger by name. Can be passed multiple times.",
dest='logger_disable',
help='Disable a logger by name. Can be passed multiple times.',
)
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
_HandlerType = TypeVar('_HandlerType', bound=logging.Handler)
# Not using @contextmanager for performance reasons.
class catching_logs(Generic[_HandlerType]):
"""Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level")
__slots__ = ('handler', 'level', 'orig_level')
def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None:
def __init__(self, handler: _HandlerType, level: int | None = None) -> None:
self.handler = handler
self.level = level
@ -363,9 +365,9 @@ class catching_logs(Generic[_HandlerType]):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
root_logger = logging.getLogger()
if self.level is not None:
@ -379,7 +381,7 @@ class LogCaptureHandler(logging_StreamHandler):
def __init__(self) -> None:
"""Create a new log handler."""
super().__init__(StringIO())
self.records: List[logging.LogRecord] = []
self.records: list[logging.LogRecord] = []
def emit(self, record: logging.LogRecord) -> None:
"""Keep the log records in a list in addition to the log text."""
@ -410,10 +412,10 @@ class LogCaptureFixture:
def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._item = item
self._initial_handler_level: Optional[int] = None
self._initial_handler_level: int | None = None
# Dict of log name -> log level.
self._initial_logger_levels: Dict[Optional[str], int] = {}
self._initial_disabled_logging_level: Optional[int] = None
self._initial_logger_levels: dict[str | None, int] = {}
self._initial_disabled_logging_level: int | None = None
def _finalize(self) -> None:
"""Finalize the fixture.
@ -437,8 +439,8 @@ class LogCaptureFixture:
return self._item.stash[caplog_handler_key]
def get_records(
self, when: Literal["setup", "call", "teardown"]
) -> List[logging.LogRecord]:
self, when: Literal['setup', 'call', 'teardown'],
) -> list[logging.LogRecord]:
"""Get the logging records for one of the possible test phases.
:param when:
@ -457,12 +459,12 @@ class LogCaptureFixture:
return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property
def records(self) -> List[logging.LogRecord]:
def records(self) -> list[logging.LogRecord]:
"""The list of log records."""
return self.handler.records
@property
def record_tuples(self) -> List[Tuple[str, int, str]]:
def record_tuples(self) -> list[tuple[str, int, str]]:
"""A list of a stripped down version of log records intended
for use in assertion comparison.
@ -473,7 +475,7 @@ class LogCaptureFixture:
return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property
def messages(self) -> List[str]:
def messages(self) -> list[str]:
"""A list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for
@ -496,7 +498,7 @@ class LogCaptureFixture:
self.handler.clear()
def _force_enable_logging(
self, level: Union[int, str], logger_obj: logging.Logger
self, level: int | str, logger_obj: logging.Logger,
) -> int:
"""Enable the desired logging level if the global level was disabled via ``logging.disabled``.
@ -529,7 +531,7 @@ class LogCaptureFixture:
return original_disable_level
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
def set_level(self, level: int | str, logger: str | None = None) -> None:
"""Set the threshold level of a logger for the duration of a test.
Logging messages which are less severe than this level will not be captured.
@ -556,7 +558,7 @@ class LogCaptureFixture:
@contextmanager
def at_level(
self, level: Union[int, str], logger: Optional[str] = None
self, level: int | str, logger: str | None = None,
) -> Generator[None, None, None]:
"""Context manager that sets the level for capturing of logs. After
the end of the 'with' statement the level is restored to its original
@ -614,7 +616,7 @@ def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
result._finalize()
def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[int]:
def get_log_level_for_setting(config: Config, *setting_names: str) -> int | None:
for setting_name in setting_names:
log_level = config.getoption(setting_name)
if log_level is None:
@ -633,14 +635,14 @@ def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[i
raise UsageError(
f"'{log_level}' is not recognized as a logging level name for "
f"'{setting_name}'. Please consider passing the "
"logging level num instead."
'logging level num instead.',
) from e
# run after terminalreporter/capturemanager are configured
@hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
config.pluginmanager.register(LoggingPlugin(config), "logging-plugin")
config.pluginmanager.register(LoggingPlugin(config), 'logging-plugin')
class LoggingPlugin:
@ -656,11 +658,11 @@ class LoggingPlugin:
# Report logging.
self.formatter = self._create_formatter(
get_option_ini(config, "log_format"),
get_option_ini(config, "log_date_format"),
get_option_ini(config, "log_auto_indent"),
get_option_ini(config, 'log_format'),
get_option_ini(config, 'log_date_format'),
get_option_ini(config, 'log_auto_indent'),
)
self.log_level = get_log_level_for_setting(config, "log_level")
self.log_level = get_log_level_for_setting(config, 'log_level')
self.caplog_handler = LogCaptureHandler()
self.caplog_handler.setFormatter(self.formatter)
self.report_handler = LogCaptureHandler()
@ -668,52 +670,52 @@ class LoggingPlugin:
# File logging.
self.log_file_level = get_log_level_for_setting(
config, "log_file_level", "log_level"
config, 'log_file_level', 'log_level',
)
log_file = get_option_ini(config, "log_file") or os.devnull
log_file = get_option_ini(config, 'log_file') or os.devnull
if log_file != os.devnull:
directory = os.path.dirname(os.path.abspath(log_file))
if not os.path.isdir(directory):
os.makedirs(directory)
self.log_file_mode = get_option_ini(config, "log_file_mode") or "w"
self.log_file_mode = get_option_ini(config, 'log_file_mode') or 'w'
self.log_file_handler = _FileHandler(
log_file, mode=self.log_file_mode, encoding="UTF-8"
log_file, mode=self.log_file_mode, encoding='UTF-8',
)
log_file_format = get_option_ini(config, "log_file_format", "log_format")
log_file_format = get_option_ini(config, 'log_file_format', 'log_format')
log_file_date_format = get_option_ini(
config, "log_file_date_format", "log_date_format"
config, 'log_file_date_format', 'log_date_format',
)
log_file_formatter = DatetimeFormatter(
log_file_format, datefmt=log_file_date_format
log_file_format, datefmt=log_file_date_format,
)
self.log_file_handler.setFormatter(log_file_formatter)
# CLI/live logging.
self.log_cli_level = get_log_level_for_setting(
config, "log_cli_level", "log_level"
config, 'log_cli_level', 'log_level',
)
if self._log_cli_enabled():
terminal_reporter = config.pluginmanager.get_plugin("terminalreporter")
terminal_reporter = config.pluginmanager.get_plugin('terminalreporter')
# Guaranteed by `_log_cli_enabled()`.
assert terminal_reporter is not None
capture_manager = config.pluginmanager.get_plugin("capturemanager")
capture_manager = config.pluginmanager.get_plugin('capturemanager')
# if capturemanager plugin is disabled, live logging still works.
self.log_cli_handler: Union[
_LiveLoggingStreamHandler, _LiveLoggingNullHandler
] = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
self.log_cli_handler: (
_LiveLoggingStreamHandler | _LiveLoggingNullHandler
) = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
else:
self.log_cli_handler = _LiveLoggingNullHandler()
log_cli_formatter = self._create_formatter(
get_option_ini(config, "log_cli_format", "log_format"),
get_option_ini(config, "log_cli_date_format", "log_date_format"),
get_option_ini(config, "log_auto_indent"),
get_option_ini(config, 'log_cli_format', 'log_format'),
get_option_ini(config, 'log_cli_date_format', 'log_date_format'),
get_option_ini(config, 'log_auto_indent'),
)
self.log_cli_handler.setFormatter(log_cli_formatter)
self._disable_loggers(loggers_to_disable=config.option.logger_disable)
def _disable_loggers(self, loggers_to_disable: List[str]) -> None:
def _disable_loggers(self, loggers_to_disable: list[str]) -> None:
if not loggers_to_disable:
return
@ -723,18 +725,18 @@ class LoggingPlugin:
def _create_formatter(self, log_format, log_date_format, auto_indent):
# Color option doesn't exist if terminal plugin is disabled.
color = getattr(self._config.option, "color", "no")
if color != "no" and ColoredLevelFormatter.LEVELNAME_FMT_REGEX.search(
log_format
color = getattr(self._config.option, 'color', 'no')
if color != 'no' and ColoredLevelFormatter.LEVELNAME_FMT_REGEX.search(
log_format,
):
formatter: logging.Formatter = ColoredLevelFormatter(
create_terminal_writer(self._config), log_format, log_date_format
create_terminal_writer(self._config), log_format, log_date_format,
)
else:
formatter = DatetimeFormatter(log_format, log_date_format)
formatter._style = PercentStyleMultiline(
formatter._style._fmt, auto_indent=auto_indent
formatter._style._fmt, auto_indent=auto_indent,
)
return formatter
@ -756,7 +758,7 @@ class LoggingPlugin:
fpath.parent.mkdir(exist_ok=True, parents=True)
# https://github.com/python/mypy/issues/11193
stream: io.TextIOWrapper = fpath.open(mode=self.log_file_mode, encoding="UTF-8") # type: ignore[assignment]
stream: io.TextIOWrapper = fpath.open(mode=self.log_file_mode, encoding='UTF-8') # type: ignore[assignment]
old_stream = self.log_file_handler.setStream(stream)
if old_stream:
old_stream.close()
@ -764,12 +766,12 @@ class LoggingPlugin:
def _log_cli_enabled(self) -> bool:
"""Return whether live logging is enabled."""
enabled = self._config.getoption(
"--log-cli-level"
) is not None or self._config.getini("log_cli")
'--log-cli-level',
) is not None or self._config.getini('log_cli')
if not enabled:
return False
terminal_reporter = self._config.pluginmanager.get_plugin("terminalreporter")
terminal_reporter = self._config.pluginmanager.get_plugin('terminalreporter')
if terminal_reporter is None:
# terminal reporter is disabled e.g. by pytest-xdist.
return False
@ -778,7 +780,7 @@ class LoggingPlugin:
@hookimpl(wrapper=True, tryfirst=True)
def pytest_sessionstart(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionstart")
self.log_cli_handler.set_when('sessionstart')
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
@ -786,7 +788,7 @@ class LoggingPlugin:
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("collection")
self.log_cli_handler.set_when('collection')
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
@ -797,7 +799,7 @@ class LoggingPlugin:
if session.config.option.collectonly:
return (yield)
if self._log_cli_enabled() and self._config.getoption("verbose") < 1:
if self._log_cli_enabled() and self._config.getoption('verbose') < 1:
# The verbose flag is needed to avoid messy test progress output.
self._config.option.verbose = 1
@ -808,11 +810,11 @@ class LoggingPlugin:
@hookimpl
def pytest_runtest_logstart(self) -> None:
self.log_cli_handler.reset()
self.log_cli_handler.set_when("start")
self.log_cli_handler.set_when('start')
@hookimpl
def pytest_runtest_logreport(self) -> None:
self.log_cli_handler.set_when("logreport")
self.log_cli_handler.set_when('logreport')
def _runtest_for(self, item: nodes.Item, when: str) -> Generator[None, None, None]:
"""Implement the internals of the pytest_runtest_xxx() hooks."""
@ -832,39 +834,39 @@ class LoggingPlugin:
yield
finally:
log = report_handler.stream.getvalue().strip()
item.add_report_section(when, "log", log)
item.add_report_section(when, 'log', log)
@hookimpl(wrapper=True)
def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("setup")
self.log_cli_handler.set_when('setup')
empty: Dict[str, List[logging.LogRecord]] = {}
empty: dict[str, list[logging.LogRecord]] = {}
item.stash[caplog_records_key] = empty
yield from self._runtest_for(item, "setup")
yield from self._runtest_for(item, 'setup')
@hookimpl(wrapper=True)
def pytest_runtest_call(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("call")
self.log_cli_handler.set_when('call')
yield from self._runtest_for(item, "call")
yield from self._runtest_for(item, 'call')
@hookimpl(wrapper=True)
def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("teardown")
self.log_cli_handler.set_when('teardown')
try:
yield from self._runtest_for(item, "teardown")
yield from self._runtest_for(item, 'teardown')
finally:
del item.stash[caplog_records_key]
del item.stash[caplog_handler_key]
@hookimpl
def pytest_runtest_logfinish(self) -> None:
self.log_cli_handler.set_when("finish")
self.log_cli_handler.set_when('finish')
@hookimpl(wrapper=True, tryfirst=True)
def pytest_sessionfinish(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionfinish")
self.log_cli_handler.set_when('sessionfinish')
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
@ -901,7 +903,7 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
def __init__(
self,
terminal_reporter: TerminalReporter,
capture_manager: Optional[CaptureManager],
capture_manager: CaptureManager | None,
) -> None:
super().__init__(stream=terminal_reporter) # type: ignore[arg-type]
self.capture_manager = capture_manager
@ -913,11 +915,11 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
"""Reset the handler; should be called before the start of each test."""
self._first_record_emitted = False
def set_when(self, when: Optional[str]) -> None:
def set_when(self, when: str | None) -> None:
"""Prepare for the given test phase (setup/call/teardown)."""
self._when = when
self._section_name_shown = False
if when == "start":
if when == 'start':
self._test_outcome_written = False
def emit(self, record: logging.LogRecord) -> None:
@ -928,14 +930,14 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
)
with ctx_manager:
if not self._first_record_emitted:
self.stream.write("\n")
self.stream.write('\n')
self._first_record_emitted = True
elif self._when in ("teardown", "finish"):
elif self._when in ('teardown', 'finish'):
if not self._test_outcome_written:
self._test_outcome_written = True
self.stream.write("\n")
self.stream.write('\n')
if not self._section_name_shown and self._when:
self.stream.section("live log " + self._when, sep="-", bold=True)
self.stream.section('live log ' + self._when, sep='-', bold=True)
self._section_name_shown = True
super().emit(record)

View file

@ -1,14 +1,15 @@
"""Core implementation of the testing process: init, session, runtest loop."""
from __future__ import annotations
import argparse
import dataclasses
import fnmatch
import functools
import importlib
import importlib.util
import os
from pathlib import Path
import sys
import warnings
from pathlib import Path
from typing import AbstractSet
from typing import Callable
from typing import Dict
@ -24,12 +25,10 @@ from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
import warnings
import pluggy
from _pytest import nodes
import _pytest._code
import pluggy
from _pytest import nodes
from _pytest.config import Config
from _pytest.config import directory_arg
from _pytest.config import ExitCode
@ -58,195 +57,195 @@ if TYPE_CHECKING:
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"norecursedirs",
"Directory patterns to avoid for recursion",
type="args",
'norecursedirs',
'Directory patterns to avoid for recursion',
type='args',
default=[
"*.egg",
".*",
"_darcs",
"build",
"CVS",
"dist",
"node_modules",
"venv",
"{arch}",
'*.egg',
'.*',
'_darcs',
'build',
'CVS',
'dist',
'node_modules',
'venv',
'{arch}',
],
)
parser.addini(
"testpaths",
"Directories to search for tests when no files or directories are given on the "
"command line",
type="args",
'testpaths',
'Directories to search for tests when no files or directories are given on the '
'command line',
type='args',
default=[],
)
group = parser.getgroup("general", "Running and selection options")
group = parser.getgroup('general', 'Running and selection options')
group._addoption(
"-x",
"--exitfirst",
action="store_const",
dest="maxfail",
'-x',
'--exitfirst',
action='store_const',
dest='maxfail',
const=1,
help="Exit instantly on first error or failed test",
help='Exit instantly on first error or failed test',
)
group = parser.getgroup("pytest-warnings")
group = parser.getgroup('pytest-warnings')
group.addoption(
"-W",
"--pythonwarnings",
action="append",
help="Set which warnings to report, see -W option of Python itself",
'-W',
'--pythonwarnings',
action='append',
help='Set which warnings to report, see -W option of Python itself',
)
parser.addini(
"filterwarnings",
type="linelist",
help="Each line specifies a pattern for "
"warnings.filterwarnings. "
"Processed after -W/--pythonwarnings.",
'filterwarnings',
type='linelist',
help='Each line specifies a pattern for '
'warnings.filterwarnings. '
'Processed after -W/--pythonwarnings.',
)
group._addoption(
"--maxfail",
metavar="num",
action="store",
'--maxfail',
metavar='num',
action='store',
type=int,
dest="maxfail",
dest='maxfail',
default=0,
help="Exit after first num failures or errors",
help='Exit after first num failures or errors',
)
group._addoption(
"--strict-config",
action="store_true",
help="Any warnings encountered while parsing the `pytest` section of the "
"configuration file raise errors",
'--strict-config',
action='store_true',
help='Any warnings encountered while parsing the `pytest` section of the '
'configuration file raise errors',
)
group._addoption(
"--strict-markers",
action="store_true",
help="Markers not registered in the `markers` section of the configuration "
"file raise errors",
'--strict-markers',
action='store_true',
help='Markers not registered in the `markers` section of the configuration '
'file raise errors',
)
group._addoption(
"--strict",
action="store_true",
help="(Deprecated) alias to --strict-markers",
'--strict',
action='store_true',
help='(Deprecated) alias to --strict-markers',
)
group._addoption(
"-c",
"--config-file",
metavar="FILE",
'-c',
'--config-file',
metavar='FILE',
type=str,
dest="inifilename",
help="Load configuration from `FILE` instead of trying to locate one of the "
"implicit configuration files.",
dest='inifilename',
help='Load configuration from `FILE` instead of trying to locate one of the '
'implicit configuration files.',
)
group._addoption(
"--continue-on-collection-errors",
action="store_true",
'--continue-on-collection-errors',
action='store_true',
default=False,
dest="continue_on_collection_errors",
help="Force test execution even if collection errors occur",
dest='continue_on_collection_errors',
help='Force test execution even if collection errors occur',
)
group._addoption(
"--rootdir",
action="store",
dest="rootdir",
'--rootdir',
action='store',
dest='rootdir',
help="Define root directory for tests. Can be relative path: 'root_dir', './root_dir', "
"'root_dir/another_dir/'; absolute path: '/home/user/root_dir'; path with variables: "
"'$HOME/root_dir'.",
)
group = parser.getgroup("collect", "collection")
group = parser.getgroup('collect', 'collection')
group.addoption(
"--collectonly",
"--collect-only",
"--co",
action="store_true",
'--collectonly',
'--collect-only',
'--co',
action='store_true',
help="Only collect tests, don't execute them",
)
group.addoption(
"--pyargs",
action="store_true",
help="Try to interpret all arguments as Python packages",
'--pyargs',
action='store_true',
help='Try to interpret all arguments as Python packages',
)
group.addoption(
"--ignore",
action="append",
metavar="path",
help="Ignore path during collection (multi-allowed)",
'--ignore',
action='append',
metavar='path',
help='Ignore path during collection (multi-allowed)',
)
group.addoption(
"--ignore-glob",
action="append",
metavar="path",
help="Ignore path pattern during collection (multi-allowed)",
'--ignore-glob',
action='append',
metavar='path',
help='Ignore path pattern during collection (multi-allowed)',
)
group.addoption(
"--deselect",
action="append",
metavar="nodeid_prefix",
help="Deselect item (via node id prefix) during collection (multi-allowed)",
'--deselect',
action='append',
metavar='nodeid_prefix',
help='Deselect item (via node id prefix) during collection (multi-allowed)',
)
group.addoption(
"--confcutdir",
dest="confcutdir",
'--confcutdir',
dest='confcutdir',
default=None,
metavar="dir",
type=functools.partial(directory_arg, optname="--confcutdir"),
metavar='dir',
type=functools.partial(directory_arg, optname='--confcutdir'),
help="Only load conftest.py's relative to specified dir",
)
group.addoption(
"--noconftest",
action="store_true",
dest="noconftest",
'--noconftest',
action='store_true',
dest='noconftest',
default=False,
help="Don't load any conftest.py files",
)
group.addoption(
"--keepduplicates",
"--keep-duplicates",
action="store_true",
dest="keepduplicates",
'--keepduplicates',
'--keep-duplicates',
action='store_true',
dest='keepduplicates',
default=False,
help="Keep duplicate tests",
help='Keep duplicate tests',
)
group.addoption(
"--collect-in-virtualenv",
action="store_true",
dest="collect_in_virtualenv",
'--collect-in-virtualenv',
action='store_true',
dest='collect_in_virtualenv',
default=False,
help="Don't ignore tests in a local virtualenv directory",
)
group.addoption(
"--import-mode",
default="prepend",
choices=["prepend", "append", "importlib"],
dest="importmode",
help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.",
'--import-mode',
default='prepend',
choices=['prepend', 'append', 'importlib'],
dest='importmode',
help='Prepend/append to sys.path when importing test modules and conftest '
'files. Default: prepend.',
)
parser.addini(
"consider_namespace_packages",
type="bool",
'consider_namespace_packages',
type='bool',
default=False,
help="Consider namespace packages when resolving module names during import",
help='Consider namespace packages when resolving module names during import',
)
group = parser.getgroup("debugconfig", "test session debugging and configuration")
group = parser.getgroup('debugconfig', 'test session debugging and configuration')
group.addoption(
"--basetemp",
dest="basetemp",
'--basetemp',
dest='basetemp',
default=None,
type=validate_basetemp,
metavar="dir",
metavar='dir',
help=(
"Base temporary directory for this test run. "
"(Warning: this directory is removed if it exists.)"
'Base temporary directory for this test run. '
'(Warning: this directory is removed if it exists.)'
),
)
def validate_basetemp(path: str) -> str:
# GH 7119
msg = "basetemp must not be empty, the current working directory or any parent directory of it"
msg = 'basetemp must not be empty, the current working directory or any parent directory of it'
# empty path
if not path:
@ -270,8 +269,8 @@ def validate_basetemp(path: str) -> str:
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
config: Config, doit: Callable[[Config, Session], int | ExitCode | None],
) -> int | ExitCode:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
@ -290,12 +289,12 @@ def wrap_session(
session.exitstatus = ExitCode.TESTS_FAILED
except (KeyboardInterrupt, exit.Exception):
excinfo = _pytest._code.ExceptionInfo.from_current()
exitstatus: Union[int, ExitCode] = ExitCode.INTERRUPTED
exitstatus: int | ExitCode = ExitCode.INTERRUPTED
if isinstance(excinfo.value, exit.Exception):
if excinfo.value.returncode is not None:
exitstatus = excinfo.value.returncode
if initstate < 2:
sys.stderr.write(f"{excinfo.typename}: {excinfo.value.msg}\n")
sys.stderr.write(f'{excinfo.typename}: {excinfo.value.msg}\n')
config.hook.pytest_keyboard_interrupt(excinfo=excinfo)
session.exitstatus = exitstatus
except BaseException:
@ -306,10 +305,10 @@ def wrap_session(
except exit.Exception as exc:
if exc.returncode is not None:
session.exitstatus = exc.returncode
sys.stderr.write(f"{type(exc).__name__}: {exc}\n")
sys.stderr.write(f'{type(exc).__name__}: {exc}\n')
else:
if isinstance(excinfo.value, SystemExit):
sys.stderr.write("mainloop: caught unexpected SystemExit!\n")
sys.stderr.write('mainloop: caught unexpected SystemExit!\n')
finally:
# Explicitly break reference cycle.
@ -318,21 +317,21 @@ def wrap_session(
if initstate >= 2:
try:
config.hook.pytest_sessionfinish(
session=session, exitstatus=session.exitstatus
session=session, exitstatus=session.exitstatus,
)
except exit.Exception as exc:
if exc.returncode is not None:
session.exitstatus = exc.returncode
sys.stderr.write(f"{type(exc).__name__}: {exc}\n")
sys.stderr.write(f'{type(exc).__name__}: {exc}\n')
config._ensure_unconfigure()
return session.exitstatus
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
def pytest_cmdline_main(config: Config) -> int | ExitCode:
return wrap_session(config, _main)
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
def _main(config: Config, session: Session) -> int | ExitCode | None:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
@ -345,15 +344,15 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None
def pytest_collection(session: "Session") -> None:
def pytest_collection(session: Session) -> None:
session.perform_collect()
def pytest_runtestloop(session: "Session") -> bool:
def pytest_runtestloop(session: Session) -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
'%d error%s during collection'
% (session.testsfailed, 's' if session.testsfailed != 1 else ''),
)
if session.config.option.collectonly:
@ -372,32 +371,32 @@ def pytest_runtestloop(session: "Session") -> bool:
def _in_venv(path: Path) -> bool:
"""Attempt to detect if ``path`` is the root of a Virtual Environment by
checking for the existence of the appropriate activate script."""
bindir = path.joinpath("Scripts" if sys.platform.startswith("win") else "bin")
bindir = path.joinpath('Scripts' if sys.platform.startswith('win') else 'bin')
try:
if not bindir.is_dir():
return False
except OSError:
return False
activates = (
"activate",
"activate.csh",
"activate.fish",
"Activate",
"Activate.bat",
"Activate.ps1",
'activate',
'activate.csh',
'activate.fish',
'Activate',
'Activate.bat',
'Activate.ps1',
)
return any(fname.name in activates for fname in bindir.iterdir())
def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[bool]:
if collection_path.name == "__pycache__":
def pytest_ignore_collect(collection_path: Path, config: Config) -> bool | None:
if collection_path.name == '__pycache__':
return True
ignore_paths = config._getconftest_pathlist(
"collect_ignore", path=collection_path.parent
'collect_ignore', path=collection_path.parent,
)
ignore_paths = ignore_paths or []
excludeopt = config.getoption("ignore")
excludeopt = config.getoption('ignore')
if excludeopt:
ignore_paths.extend(absolutepath(x) for x in excludeopt)
@ -405,22 +404,22 @@ def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[boo
return True
ignore_globs = config._getconftest_pathlist(
"collect_ignore_glob", path=collection_path.parent
'collect_ignore_glob', path=collection_path.parent,
)
ignore_globs = ignore_globs or []
excludeglobopt = config.getoption("ignore_glob")
excludeglobopt = config.getoption('ignore_glob')
if excludeglobopt:
ignore_globs.extend(absolutepath(x) for x in excludeglobopt)
if any(fnmatch.fnmatch(str(collection_path), str(glob)) for glob in ignore_globs):
return True
allow_in_venv = config.getoption("collect_in_virtualenv")
allow_in_venv = config.getoption('collect_in_virtualenv')
if not allow_in_venv and _in_venv(collection_path):
return True
if collection_path.is_dir():
norecursepatterns = config.getini("norecursedirs")
norecursepatterns = config.getini('norecursedirs')
if any(fnmatch_ex(pat, collection_path) for pat in norecursepatterns):
return True
@ -428,13 +427,13 @@ def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[boo
def pytest_collect_directory(
path: Path, parent: nodes.Collector
) -> Optional[nodes.Collector]:
path: Path, parent: nodes.Collector,
) -> nodes.Collector | None:
return Dir.from_parent(parent, path=path)
def pytest_collection_modifyitems(items: List[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or [])
def pytest_collection_modifyitems(items: list[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption('deselect') or [])
if not deselect_prefixes:
return
@ -469,7 +468,7 @@ class FSHookProxy:
class Interrupted(KeyboardInterrupt):
"""Signals that the test run was interrupted."""
__module__ = "builtins" # For py3.
__module__ = 'builtins' # For py3.
class Failed(Exception):
@ -478,7 +477,7 @@ class Failed(Exception):
@dataclasses.dataclass
class _bestrelpath_cache(Dict[Path, str]):
__slots__ = ("path",)
__slots__ = ('path',)
path: Path
@ -507,7 +506,7 @@ class Dir(nodes.Directory):
parent: nodes.Collector,
*,
path: Path,
) -> "Self":
) -> Self:
"""The public constructor.
:param parent: The parent collector of this Dir.
@ -515,9 +514,9 @@ class Dir(nodes.Directory):
"""
return super().from_parent(parent=parent, path=path)
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
config = self.config
col: Optional[nodes.Collector]
col: nodes.Collector | None
cols: Sequence[nodes.Collector]
ihook = self.ihook
for direntry in scandir(self.path):
@ -552,60 +551,60 @@ class Session(nodes.Collector):
_setupstate: SetupState
# Set on the session by fixtures.pytest_sessionstart.
_fixturemanager: FixtureManager
exitstatus: Union[int, ExitCode]
exitstatus: int | ExitCode
def __init__(self, config: Config) -> None:
super().__init__(
name="",
name='',
path=config.rootpath,
fspath=None,
parent=None,
config=config,
session=self,
nodeid="",
nodeid='',
)
self.testsfailed = 0
self.testscollected = 0
self._shouldstop: Union[bool, str] = False
self._shouldfail: Union[bool, str] = False
self.trace = config.trace.root.get("collection")
self._initialpaths: FrozenSet[Path] = frozenset()
self._initialpaths_with_parents: FrozenSet[Path] = frozenset()
self._notfound: List[Tuple[str, Sequence[nodes.Collector]]] = []
self._initial_parts: List[CollectionArgument] = []
self._collection_cache: Dict[nodes.Collector, CollectReport] = {}
self.items: List[nodes.Item] = []
self._shouldstop: bool | str = False
self._shouldfail: bool | str = False
self.trace = config.trace.root.get('collection')
self._initialpaths: frozenset[Path] = frozenset()
self._initialpaths_with_parents: frozenset[Path] = frozenset()
self._notfound: list[tuple[str, Sequence[nodes.Collector]]] = []
self._initial_parts: list[CollectionArgument] = []
self._collection_cache: dict[nodes.Collector, CollectReport] = {}
self.items: list[nodes.Item] = []
self._bestrelpathcache: Dict[Path, str] = _bestrelpath_cache(config.rootpath)
self._bestrelpathcache: dict[Path, str] = _bestrelpath_cache(config.rootpath)
self.config.pluginmanager.register(self, name="session")
self.config.pluginmanager.register(self, name='session')
@classmethod
def from_config(cls, config: Config) -> "Session":
def from_config(cls, config: Config) -> Session:
session: Session = cls._create(config=config)
return session
def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % (
return '<%s %s exitstatus=%r testsfailed=%d testscollected=%d>' % (
self.__class__.__name__,
self.name,
getattr(self, "exitstatus", "<UNSET>"),
getattr(self, 'exitstatus', '<UNSET>'),
self.testsfailed,
self.testscollected,
)
@property
def shouldstop(self) -> Union[bool, str]:
def shouldstop(self) -> bool | str:
return self._shouldstop
@shouldstop.setter
def shouldstop(self, value: Union[bool, str]) -> None:
def shouldstop(self, value: bool | str) -> None:
# The runner checks shouldfail and assumes that if it is set we are
# definitely stopping, so prevent unsetting it.
if value is False and self._shouldstop:
warnings.warn(
PytestWarning(
"session.shouldstop cannot be unset after it has been set; ignoring."
'session.shouldstop cannot be unset after it has been set; ignoring.',
),
stacklevel=2,
)
@ -613,17 +612,17 @@ class Session(nodes.Collector):
self._shouldstop = value
@property
def shouldfail(self) -> Union[bool, str]:
def shouldfail(self) -> bool | str:
return self._shouldfail
@shouldfail.setter
def shouldfail(self, value: Union[bool, str]) -> None:
def shouldfail(self, value: bool | str) -> None:
# The runner checks shouldfail and assumes that if it is set we are
# definitely stopping, so prevent unsetting it.
if value is False and self._shouldfail:
warnings.warn(
PytestWarning(
"session.shouldfail cannot be unset after it has been set; ignoring."
'session.shouldfail cannot be unset after it has been set; ignoring.',
),
stacklevel=2,
)
@ -651,19 +650,19 @@ class Session(nodes.Collector):
@hookimpl(tryfirst=True)
def pytest_runtest_logreport(
self, report: Union[TestReport, CollectReport]
self, report: TestReport | CollectReport,
) -> None:
if report.failed and not hasattr(report, "wasxfail"):
if report.failed and not hasattr(report, 'wasxfail'):
self.testsfailed += 1
maxfail = self.config.getvalue("maxfail")
maxfail = self.config.getvalue('maxfail')
if maxfail and self.testsfailed >= maxfail:
self.shouldfail = "stopping after %d failures" % (self.testsfailed)
self.shouldfail = 'stopping after %d failures' % (self.testsfailed)
pytest_collectreport = pytest_runtest_logreport
def isinitpath(
self,
path: Union[str, "os.PathLike[str]"],
path: str | os.PathLike[str],
*,
with_parents: bool = False,
) -> bool:
@ -685,7 +684,7 @@ class Session(nodes.Collector):
else:
return path_ in self._initialpaths
def gethookproxy(self, fspath: "os.PathLike[str]") -> pluggy.HookRelay:
def gethookproxy(self, fspath: os.PathLike[str]) -> pluggy.HookRelay:
# Optimization: Path(Path(...)) is much slower than isinstance.
path = fspath if isinstance(fspath, Path) else Path(fspath)
pm = self.config.pluginmanager
@ -705,7 +704,7 @@ class Session(nodes.Collector):
def _collect_path(
self,
path: Path,
path_cache: Dict[Path, Sequence[nodes.Collector]],
path_cache: dict[Path, Sequence[nodes.Collector]],
) -> Sequence[nodes.Collector]:
"""Create a Collector for the given path.
@ -717,8 +716,8 @@ class Session(nodes.Collector):
if path.is_dir():
ihook = self.gethookproxy(path.parent)
col: Optional[nodes.Collector] = ihook.pytest_collect_directory(
path=path, parent=self
col: nodes.Collector | None = ihook.pytest_collect_directory(
path=path, parent=self,
)
cols: Sequence[nodes.Collector] = (col,) if col is not None else ()
@ -735,19 +734,19 @@ class Session(nodes.Collector):
@overload
def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ...
self, args: Sequence[str] | None = ..., genitems: Literal[True] = ...,
) -> Sequence[nodes.Item]:
...
@overload
def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: bool = ...
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
self, args: Sequence[str] | None = ..., genitems: bool = ...,
) -> Sequence[nodes.Item | nodes.Collector]:
...
def perform_collect(
self, args: Optional[Sequence[str]] = None, genitems: bool = True
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
self, args: Sequence[str] | None = None, genitems: bool = True,
) -> Sequence[nodes.Item | nodes.Collector]:
"""Perform the collection phase for this session.
This is called by the default :hook:`pytest_collection` hook
@ -764,7 +763,7 @@ class Session(nodes.Collector):
if args is None:
args = self.config.args
self.trace("perform_collect", self, args)
self.trace('perform_collect', self, args)
self.trace.root.indent += 1
hook = self.config.hook
@ -773,10 +772,10 @@ class Session(nodes.Collector):
self._initial_parts = []
self._collection_cache = {}
self.items = []
items: Sequence[Union[nodes.Item, nodes.Collector]] = self.items
items: Sequence[nodes.Item | nodes.Collector] = self.items
try:
initialpaths: List[Path] = []
initialpaths_with_parents: List[Path] = []
initialpaths: list[Path] = []
initialpaths_with_parents: list[Path] = []
for arg in args:
collection_argument = resolve_collection_argument(
self.config.invocation_params.dir,
@ -798,10 +797,10 @@ class Session(nodes.Collector):
for arg, collectors in self._notfound:
if collectors:
errors.append(
f"not found: {arg}\n(no match in any of {collectors!r})"
f'not found: {arg}\n(no match in any of {collectors!r})',
)
else:
errors.append(f"found no collectors for {arg}")
errors.append(f'found no collectors for {arg}')
raise UsageError(*errors)
@ -814,7 +813,7 @@ class Session(nodes.Collector):
self.config.pluginmanager.check_pending()
hook.pytest_collection_modifyitems(
session=self, config=self.config, items=items
session=self, config=self.config, items=items,
)
finally:
self._notfound = []
@ -831,7 +830,7 @@ class Session(nodes.Collector):
self,
node: nodes.Collector,
handle_dupes: bool = True,
) -> Tuple[CollectReport, bool]:
) -> tuple[CollectReport, bool]:
if node in self._collection_cache and handle_dupes:
rep = self._collection_cache[node]
return rep, True
@ -840,16 +839,16 @@ class Session(nodes.Collector):
self._collection_cache[node] = rep
return rep, False
def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]:
def collect(self) -> Iterator[nodes.Item | nodes.Collector]:
# This is a cache for the root directories of the initial paths.
# We can't use collection_cache for Session because of its special
# role as the bootstrapping collector.
path_cache: Dict[Path, Sequence[nodes.Collector]] = {}
path_cache: dict[Path, Sequence[nodes.Collector]] = {}
pm = self.config.pluginmanager
for collection_argument in self._initial_parts:
self.trace("processing argument", collection_argument)
self.trace('processing argument', collection_argument)
self.trace.root.indent += 1
argpath = collection_argument.path
@ -858,7 +857,7 @@ class Session(nodes.Collector):
# resolve_collection_argument() ensures this.
if argpath.is_dir():
assert not names, f"invalid arg {(argpath, names)!r}"
assert not names, f'invalid arg {(argpath, names)!r}'
paths = [argpath]
# Add relevant parents of the path, from the root, e.g.
@ -872,7 +871,7 @@ class Session(nodes.Collector):
else:
# For --pyargs arguments, only consider paths matching the module
# name. Paths beyond the package hierarchy are not included.
module_name_parts = module_name.split(".")
module_name_parts = module_name.split('.')
for i, path in enumerate(argpath.parents, 2):
if i > len(module_name_parts) or path.stem != module_name_parts[-i]:
break
@ -882,8 +881,8 @@ class Session(nodes.Collector):
# and discarding all nodes which don't match the level's part.
any_matched_in_initial_part = False
notfound_collectors = []
work: List[
Tuple[Union[nodes.Collector, nodes.Item], List[Union[Path, str]]]
work: list[
tuple[nodes.Collector | nodes.Item, list[Path | str]]
] = [(self, [*paths, *names])]
while work:
matchnode, matchparts = work.pop()
@ -901,7 +900,7 @@ class Session(nodes.Collector):
# Collect this level of matching.
# Collecting Session (self) is done directly to avoid endless
# recursion to this function.
subnodes: Sequence[Union[nodes.Collector, nodes.Item]]
subnodes: Sequence[nodes.Collector | nodes.Item]
if isinstance(matchnode, Session):
assert isinstance(matchparts[0], Path)
subnodes = matchnode._collect_path(matchparts[0], path_cache)
@ -909,9 +908,9 @@ class Session(nodes.Collector):
# For backward compat, files given directly multiple
# times on the command line should not be deduplicated.
handle_dupes = not (
len(matchparts) == 1
and isinstance(matchparts[0], Path)
and matchparts[0].is_file()
len(matchparts) == 1 and
isinstance(matchparts[0], Path) and
matchparts[0].is_file()
)
rep, duplicate = self._collect_one_node(matchnode, handle_dupes)
if not duplicate and not rep.passed:
@ -929,15 +928,15 @@ class Session(nodes.Collector):
# Path part e.g. `/a/b/` in `/a/b/test_file.py::TestIt::test_it`.
if isinstance(matchparts[0], Path):
is_match = node.path == matchparts[0]
if sys.platform == "win32" and not is_match:
if sys.platform == 'win32' and not is_match:
# In case the file paths do not match, fallback to samefile() to
# account for short-paths on Windows (#11895).
same_file = os.path.samefile(node.path, matchparts[0])
# We don't want to match links to the current node,
# otherwise we would match the same file more than once (#12039).
is_match = same_file and (
os.path.islink(node.path)
== os.path.islink(matchparts[0])
os.path.islink(node.path) ==
os.path.islink(matchparts[0])
)
# Name part e.g. `TestIt` in `/a/b/test_file.py::TestIt::test_it`.
@ -945,8 +944,8 @@ class Session(nodes.Collector):
# TODO: Remove parametrized workaround once collection structure contains
# parametrization.
is_match = (
node.name == matchparts[0]
or node.name.split("[")[0] == matchparts[0]
node.name == matchparts[0] or
node.name.split('[')[0] == matchparts[0]
)
if is_match:
work.append((node, matchparts[1:]))
@ -956,21 +955,21 @@ class Session(nodes.Collector):
notfound_collectors.append(matchnode)
if not any_matched_in_initial_part:
report_arg = "::".join((str(argpath), *names))
report_arg = '::'.join((str(argpath), *names))
self._notfound.append((report_arg, notfound_collectors))
self.trace.root.indent -= 1
def genitems(
self, node: Union[nodes.Item, nodes.Collector]
self, node: nodes.Item | nodes.Collector,
) -> Iterator[nodes.Item]:
self.trace("genitems", node)
self.trace('genitems', node)
if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node)
yield node
else:
assert isinstance(node, nodes.Collector)
keepduplicates = self.config.getoption("keepduplicates")
keepduplicates = self.config.getoption('keepduplicates')
# For backward compat, dedup only applies to files.
handle_dupes = not (keepduplicates and isinstance(node, nodes.File))
rep, duplicate = self._collect_one_node(node, handle_dupes)
@ -983,7 +982,7 @@ class Session(nodes.Collector):
node.ihook.pytest_collectreport(report=rep)
def search_pypath(module_name: str) -> Optional[str]:
def search_pypath(module_name: str) -> str | None:
"""Search sys.path for the given a dotted module name, and return its file
system path if found."""
try:
@ -993,7 +992,7 @@ def search_pypath(module_name: str) -> Optional[str]:
# ValueError: not a module name
except (AttributeError, ImportError, ValueError):
return None
if spec is None or spec.origin is None or spec.origin == "namespace":
if spec is None or spec.origin is None or spec.origin == 'namespace':
return None
elif spec.submodule_search_locations:
return os.path.dirname(spec.origin)
@ -1007,11 +1006,11 @@ class CollectionArgument:
path: Path
parts: Sequence[str]
module_name: Optional[str]
module_name: str | None
def resolve_collection_argument(
invocation_path: Path, arg: str, *, as_pypath: bool = False
invocation_path: Path, arg: str, *, as_pypath: bool = False,
) -> CollectionArgument:
"""Parse path arguments optionally containing selection parts and return (fspath, names).
@ -1045,10 +1044,10 @@ def resolve_collection_argument(
If the path doesn't exist, raise UsageError.
If the path is a directory and selection parts are present, raise UsageError.
"""
base, squacket, rest = str(arg).partition("[")
strpath, *parts = base.split("::")
base, squacket, rest = str(arg).partition('[')
strpath, *parts = base.split('::')
if parts:
parts[-1] = f"{parts[-1]}{squacket}{rest}"
parts[-1] = f'{parts[-1]}{squacket}{rest}'
module_name = None
if as_pypath:
pyarg_strpath = search_pypath(strpath)
@ -1059,16 +1058,16 @@ def resolve_collection_argument(
fspath = absolutepath(fspath)
if not safe_exists(fspath):
msg = (
"module or package not found: {arg} (missing __init__.py?)"
'module or package not found: {arg} (missing __init__.py?)'
if as_pypath
else "file or directory not found: {arg}"
else 'file or directory not found: {arg}'
)
raise UsageError(msg.format(arg=arg))
if parts and fspath.is_dir():
msg = (
"package argument cannot contain :: selection parts: {arg}"
'package argument cannot contain :: selection parts: {arg}'
if as_pypath
else "directory argument cannot contain :: selection parts: {arg}"
else 'directory argument cannot contain :: selection parts: {arg}'
)
raise UsageError(msg.format(arg=arg))
return CollectionArgument(

View file

@ -1,4 +1,5 @@
"""Generic mechanism for marking and selecting python functions."""
from __future__ import annotations
import dataclasses
from typing import AbstractSet
@ -8,6 +9,13 @@ from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
from .expression import Expression
from .expression import ParseError
from .structures import EMPTY_PARAMETERSET_OPTION
@ -17,12 +25,6 @@ from .structures import MARK_GEN
from .structures import MarkDecorator
from .structures import MarkGenerator
from .structures import ParameterSet
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
if TYPE_CHECKING:
@ -30,12 +32,12 @@ if TYPE_CHECKING:
__all__ = [
"MARK_GEN",
"Mark",
"MarkDecorator",
"MarkGenerator",
"ParameterSet",
"get_empty_parameterset_mark",
'MARK_GEN',
'Mark',
'MarkDecorator',
'MarkGenerator',
'ParameterSet',
'get_empty_parameterset_mark',
]
@ -44,8 +46,8 @@ old_mark_config_key = StashKey[Optional[Config]]()
def param(
*values: object,
marks: Union[MarkDecorator, Collection[Union[MarkDecorator, Mark]]] = (),
id: Optional[str] = None,
marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: str | None = None,
) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`.
@ -70,59 +72,59 @@ def param(
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group = parser.getgroup('general')
group._addoption(
"-k",
action="store",
dest="keyword",
default="",
metavar="EXPRESSION",
help="Only run tests which match the given substring expression. "
"An expression is a Python evaluatable expression "
"where all names are substring-matched against test names "
'-k',
action='store',
dest='keyword',
default='',
metavar='EXPRESSION',
help='Only run tests which match the given substring expression. '
'An expression is a Python evaluatable expression '
'where all names are substring-matched against test names '
"and their parent classes. Example: -k 'test_method or test_"
"other' matches all test functions and classes whose name "
"contains 'test_method' or 'test_other', while -k 'not test_method' "
"matches those that don't contain 'test_method' in their names. "
"-k 'not test_method and not test_other' will eliminate the matches. "
"Additionally keywords are matched to classes and functions "
'Additionally keywords are matched to classes and functions '
"containing extra names in their 'extra_keyword_matches' set, "
"as well as functions which have names assigned directly to them. "
"The matching is case-insensitive.",
'as well as functions which have names assigned directly to them. '
'The matching is case-insensitive.',
)
group._addoption(
"-m",
action="store",
dest="markexpr",
default="",
metavar="MARKEXPR",
help="Only run tests matching given mark expression. "
'-m',
action='store',
dest='markexpr',
default='',
metavar='MARKEXPR',
help='Only run tests matching given mark expression. '
"For example: -m 'mark1 and not mark2'.",
)
group.addoption(
"--markers",
action="store_true",
help="show markers (builtin, plugin and per-project ones).",
'--markers',
action='store_true',
help='show markers (builtin, plugin and per-project ones).',
)
parser.addini("markers", "Register new markers for test functions", "linelist")
parser.addini(EMPTY_PARAMETERSET_OPTION, "Default marker for empty parametersets")
parser.addini('markers', 'Register new markers for test functions', 'linelist')
parser.addini(EMPTY_PARAMETERSET_OPTION, 'Default marker for empty parametersets')
@hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
import _pytest.config
if config.option.markers:
config._do_configure()
tw = _pytest.config.create_terminal_writer(config)
for line in config.getini("markers"):
parts = line.split(":", 1)
for line in config.getini('markers'):
parts = line.split(':', 1)
name = parts[0]
rest = parts[1] if len(parts) == 2 else ""
tw.write("@pytest.mark.%s:" % name, bold=True)
rest = parts[1] if len(parts) == 2 else ''
tw.write('@pytest.mark.%s:' % name, bold=True)
tw.line(rest)
tw.line()
config._ensure_unconfigure()
@ -146,12 +148,12 @@ class KeywordMatcher:
any item, as well as names directly assigned to test functions.
"""
__slots__ = ("_names",)
__slots__ = ('_names',)
_names: AbstractSet[str]
@classmethod
def from_item(cls, item: "Item") -> "KeywordMatcher":
def from_item(cls, item: Item) -> KeywordMatcher:
mapped_names = set()
# Add the names of the current item and any parent items,
@ -163,7 +165,7 @@ class KeywordMatcher:
if isinstance(node, pytest.Session):
continue
if isinstance(node, pytest.Directory) and isinstance(
node.parent, pytest.Session
node.parent, pytest.Session,
):
continue
mapped_names.add(node.name)
@ -172,7 +174,7 @@ class KeywordMatcher:
mapped_names.update(item.listextrakeywords())
# Add the names attached to the current function through direct assignment.
function_obj = getattr(item, "function", None)
function_obj = getattr(item, 'function', None)
if function_obj:
mapped_names.update(function_obj.__dict__)
@ -191,7 +193,7 @@ class KeywordMatcher:
return False
def deselect_by_keyword(items: "List[Item]", config: Config) -> None:
def deselect_by_keyword(items: List[Item], config: Config) -> None:
keywordexpr = config.option.keyword.lstrip()
if not keywordexpr:
return
@ -218,12 +220,12 @@ class MarkMatcher:
Tries to match on any marker names, attached to the given colitem.
"""
__slots__ = ("own_mark_names",)
__slots__ = ('own_mark_names',)
own_mark_names: AbstractSet[str]
@classmethod
def from_item(cls, item: "Item") -> "MarkMatcher":
def from_item(cls, item: Item) -> MarkMatcher:
mark_names = {mark.name for mark in item.iter_markers()}
return cls(mark_names)
@ -231,14 +233,14 @@ class MarkMatcher:
return name in self.own_mark_names
def deselect_by_mark(items: "List[Item]", config: Config) -> None:
def deselect_by_mark(items: List[Item], config: Config) -> None:
matchexpr = config.option.markexpr
if not matchexpr:
return
expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'")
remaining: List[Item] = []
deselected: List[Item] = []
remaining: list[Item] = []
deselected: list[Item] = []
for item in items:
if expr.evaluate(MarkMatcher.from_item(item)):
remaining.append(item)
@ -253,10 +255,10 @@ def _parse_expression(expr: str, exc_message: str) -> Expression:
try:
return Expression.compile(expr)
except ParseError as e:
raise UsageError(f"{exc_message}: {expr}: {e}") from None
raise UsageError(f'{exc_message}: {expr}: {e}') from None
def pytest_collection_modifyitems(items: "List[Item]", config: Config) -> None:
def pytest_collection_modifyitems(items: List[Item], config: Config) -> None:
deselect_by_keyword(items, config)
deselect_by_mark(items, config)
@ -267,10 +269,10 @@ def pytest_configure(config: Config) -> None:
empty_parameterset = config.getini(EMPTY_PARAMETERSET_OPTION)
if empty_parameterset not in ("skip", "xfail", "fail_at_collect", None, ""):
if empty_parameterset not in ('skip', 'xfail', 'fail_at_collect', None, ''):
raise UsageError(
f"{EMPTY_PARAMETERSET_OPTION!s} must be one of skip, xfail or fail_at_collect"
f" but it is {empty_parameterset!r}"
f'{EMPTY_PARAMETERSET_OPTION!s} must be one of skip, xfail or fail_at_collect'
f' but it is {empty_parameterset!r}',
)

View file

@ -14,6 +14,7 @@ The semantics are:
- ident evaluates to True of False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
"""
from __future__ import annotations
import ast
import dataclasses
@ -29,24 +30,24 @@ from typing import Sequence
__all__ = [
"Expression",
"ParseError",
'Expression',
'ParseError',
]
class TokenType(enum.Enum):
LPAREN = "left parenthesis"
RPAREN = "right parenthesis"
OR = "or"
AND = "and"
NOT = "not"
IDENT = "identifier"
EOF = "end of input"
LPAREN = 'left parenthesis'
RPAREN = 'right parenthesis'
OR = 'or'
AND = 'and'
NOT = 'not'
IDENT = 'identifier'
EOF = 'end of input'
@dataclasses.dataclass(frozen=True)
class Token:
__slots__ = ("type", "value", "pos")
__slots__ = ('type', 'value', 'pos')
type: TokenType
value: str
pos: int
@ -64,11 +65,11 @@ class ParseError(Exception):
self.message = message
def __str__(self) -> str:
return f"at column {self.column}: {self.message}"
return f'at column {self.column}: {self.message}'
class Scanner:
__slots__ = ("tokens", "current")
__slots__ = ('tokens', 'current')
def __init__(self, input: str) -> None:
self.tokens = self.lex(input)
@ -77,23 +78,23 @@ class Scanner:
def lex(self, input: str) -> Iterator[Token]:
pos = 0
while pos < len(input):
if input[pos] in (" ", "\t"):
if input[pos] in (' ', '\t'):
pos += 1
elif input[pos] == "(":
yield Token(TokenType.LPAREN, "(", pos)
elif input[pos] == '(':
yield Token(TokenType.LPAREN, '(', pos)
pos += 1
elif input[pos] == ")":
yield Token(TokenType.RPAREN, ")", pos)
elif input[pos] == ')':
yield Token(TokenType.RPAREN, ')', pos)
pos += 1
else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
match = re.match(r'(:?\w|:|\+|-|\.|\[|\]|\\|/)+', input[pos:])
if match:
value = match.group(0)
if value == "or":
if value == 'or':
yield Token(TokenType.OR, value, pos)
elif value == "and":
elif value == 'and':
yield Token(TokenType.AND, value, pos)
elif value == "not":
elif value == 'not':
yield Token(TokenType.NOT, value, pos)
else:
yield Token(TokenType.IDENT, value, pos)
@ -103,9 +104,9 @@ class Scanner:
pos + 1,
f'unexpected character "{input[pos]}"',
)
yield Token(TokenType.EOF, "", pos)
yield Token(TokenType.EOF, '', pos)
def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]:
def accept(self, type: TokenType, *, reject: bool = False) -> Token | None:
if self.current.type is type:
token = self.current
if token.type is not TokenType.EOF:
@ -118,8 +119,8 @@ class Scanner:
def reject(self, expected: Sequence[TokenType]) -> NoReturn:
raise ParseError(
self.current.pos + 1,
"expected {}; got {}".format(
" OR ".join(type.value for type in expected),
'expected {}; got {}'.format(
' OR '.join(type.value for type in expected),
self.current.type.value,
),
)
@ -128,7 +129,7 @@ class Scanner:
# True, False and None are legal match expression identifiers,
# but illegal as Python identifiers. To fix this, this prefix
# is added to identifiers in the conversion to Python AST.
IDENT_PREFIX = "$"
IDENT_PREFIX = '$'
def expression(s: Scanner) -> ast.Expression:
@ -176,7 +177,7 @@ class MatcherAdapter(Mapping[str, bool]):
self.matcher = matcher
def __getitem__(self, key: str) -> bool:
return self.matcher(key[len(IDENT_PREFIX) :])
return self.matcher(key[len(IDENT_PREFIX):])
def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
@ -191,13 +192,13 @@ class Expression:
The expression can be evaluated against different matchers.
"""
__slots__ = ("code",)
__slots__ = ('code',)
def __init__(self, code: types.CodeType) -> None:
self.code = code
@classmethod
def compile(self, input: str) -> "Expression":
def compile(self, input: str) -> Expression:
"""Compile a match expression.
:param input: The input expression - one line.
@ -205,8 +206,8 @@ class Expression:
astexpr = expression(Scanner(input))
code: types.CodeType = compile(
astexpr,
filename="<pytest match expression>",
mode="eval",
filename='<pytest match expression>',
mode='eval',
)
return Expression(code)
@ -219,5 +220,5 @@ class Expression:
:returns: Whether the expression matches or not.
"""
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
ret: bool = eval(self.code, {'__builtins__': {}}, MatcherAdapter(matcher))
return ret

View file

@ -1,7 +1,10 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import collections.abc
import dataclasses
import inspect
import warnings
from typing import Any
from typing import Callable
from typing import Collection
@ -21,37 +24,37 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import warnings
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import NOTSET
from ..compat import NotSetType
from _pytest.config import Config
from _pytest.deprecated import check_ispytest
from _pytest.deprecated import MARKED_FIXTURE
from _pytest.outcomes import fail
from _pytest.warning_types import PytestUnknownMarkWarning
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import NOTSET
from ..compat import NotSetType
if TYPE_CHECKING:
from ..nodes import Node
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"
EMPTY_PARAMETERSET_OPTION = 'empty_parameter_set_mark'
def istestfunc(func) -> bool:
return callable(func) and getattr(func, "__name__", "<lambda>") != "<lambda>"
return callable(func) and getattr(func, '__name__', '<lambda>') != '<lambda>'
def get_empty_parameterset_mark(
config: Config, argnames: Sequence[str], func
) -> "MarkDecorator":
config: Config, argnames: Sequence[str], func,
) -> MarkDecorator:
from ..nodes import Collector
fs, lineno = getfslineno(func)
reason = "got empty parameter set %r, function %s at %s:%d" % (
reason = 'got empty parameter set %r, function %s at %s:%d' % (
argnames,
func.__name__,
fs,
@ -59,15 +62,15 @@ def get_empty_parameterset_mark(
)
requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
if requested_mark in ("", None, "skip"):
if requested_mark in ('', None, 'skip'):
mark = MARK_GEN.skip(reason=reason)
elif requested_mark == "xfail":
elif requested_mark == 'xfail':
mark = MARK_GEN.xfail(reason=reason, run=False)
elif requested_mark == "fail_at_collect":
elif requested_mark == 'fail_at_collect':
f_name = func.__name__
_, lineno = getfslineno(func)
raise Collector.CollectError(
"Empty parameter set in '%s' at line %d" % (f_name, lineno + 1)
"Empty parameter set in '%s' at line %d" % (f_name, lineno + 1),
)
else:
raise LookupError(requested_mark)
@ -75,17 +78,17 @@ def get_empty_parameterset_mark(
class ParameterSet(NamedTuple):
values: Sequence[Union[object, NotSetType]]
marks: Collection[Union["MarkDecorator", "Mark"]]
id: Optional[str]
values: Sequence[object | NotSetType]
marks: Collection[MarkDecorator | Mark]
id: str | None
@classmethod
def param(
cls,
*values: object,
marks: Union["MarkDecorator", Collection[Union["MarkDecorator", "Mark"]]] = (),
id: Optional[str] = None,
) -> "ParameterSet":
marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: str | None = None,
) -> ParameterSet:
if isinstance(marks, MarkDecorator):
marks = (marks,)
else:
@ -93,16 +96,16 @@ class ParameterSet(NamedTuple):
if id is not None:
if not isinstance(id, str):
raise TypeError(f"Expected id to be a string, got {type(id)}: {id!r}")
raise TypeError(f'Expected id to be a string, got {type(id)}: {id!r}')
id = ascii_escaped(id)
return cls(values, marks, id)
@classmethod
def extract_from(
cls,
parameterset: Union["ParameterSet", Sequence[object], object],
parameterset: ParameterSet | Sequence[object] | object,
force_tuple: bool = False,
) -> "ParameterSet":
) -> ParameterSet:
"""Extract from an object or objects.
:param parameterset:
@ -127,13 +130,13 @@ class ParameterSet(NamedTuple):
@staticmethod
def _parse_parametrize_args(
argnames: Union[str, Sequence[str]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
*args,
**kwargs,
) -> Tuple[Sequence[str], bool]:
) -> tuple[Sequence[str], bool]:
if isinstance(argnames, str):
argnames = [x.strip() for x in argnames.split(",") if x.strip()]
argnames = [x.strip() for x in argnames.split(',') if x.strip()]
force_tuple = len(argnames) == 1
else:
force_tuple = False
@ -141,9 +144,9 @@ class ParameterSet(NamedTuple):
@staticmethod
def _parse_parametrize_parameters(
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
force_tuple: bool,
) -> List["ParameterSet"]:
) -> list[ParameterSet]:
return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
]
@ -151,12 +154,12 @@ class ParameterSet(NamedTuple):
@classmethod
def _for_parametrize(
cls,
argnames: Union[str, Sequence[str]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
func,
config: Config,
nodeid: str,
) -> Tuple[Sequence[str], List["ParameterSet"]]:
) -> tuple[Sequence[str], list[ParameterSet]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues
@ -167,9 +170,9 @@ class ParameterSet(NamedTuple):
if len(param.values) != len(argnames):
msg = (
'{nodeid}: in "parametrize" the number of names ({names_len}):\n'
" {names}\n"
"must be equal to the number of values ({values_len}):\n"
" {values}"
' {names}\n'
'must be equal to the number of values ({values_len}):\n'
' {values}'
)
fail(
msg.format(
@ -186,7 +189,7 @@ class ParameterSet(NamedTuple):
# parameter set with NOTSET values, with the "empty parameter set" mark applied to it.
mark = get_empty_parameterset_mark(config, argnames, func)
parameters.append(
ParameterSet(values=(NOTSET,) * len(argnames), marks=[mark], id=None)
ParameterSet(values=(NOTSET,) * len(argnames), marks=[mark], id=None),
)
return argnames, parameters
@ -199,40 +202,40 @@ class Mark:
#: Name of the mark.
name: str
#: Positional arguments of the mark decorator.
args: Tuple[Any, ...]
args: tuple[Any, ...]
#: Keyword arguments of the mark decorator.
kwargs: Mapping[str, Any]
#: Source Mark for ids with parametrize Marks.
_param_ids_from: Optional["Mark"] = dataclasses.field(default=None, repr=False)
_param_ids_from: Mark | None = dataclasses.field(default=None, repr=False)
#: Resolved/generated ids with parametrize Marks.
_param_ids_generated: Optional[Sequence[str]] = dataclasses.field(
default=None, repr=False
_param_ids_generated: Sequence[str] | None = dataclasses.field(
default=None, repr=False,
)
def __init__(
self,
name: str,
args: Tuple[Any, ...],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
param_ids_from: Optional["Mark"] = None,
param_ids_generated: Optional[Sequence[str]] = None,
param_ids_from: Mark | None = None,
param_ids_generated: Sequence[str] | None = None,
*,
_ispytest: bool = False,
) -> None:
""":meta private:"""
check_ispytest(_ispytest)
# Weirdness to bypass frozen=True.
object.__setattr__(self, "name", name)
object.__setattr__(self, "args", args)
object.__setattr__(self, "kwargs", kwargs)
object.__setattr__(self, "_param_ids_from", param_ids_from)
object.__setattr__(self, "_param_ids_generated", param_ids_generated)
object.__setattr__(self, 'name', name)
object.__setattr__(self, 'args', args)
object.__setattr__(self, 'kwargs', kwargs)
object.__setattr__(self, '_param_ids_from', param_ids_from)
object.__setattr__(self, '_param_ids_generated', param_ids_generated)
def _has_param_ids(self) -> bool:
return "ids" in self.kwargs or len(self.args) >= 4
return 'ids' in self.kwargs or len(self.args) >= 4
def combined_with(self, other: "Mark") -> "Mark":
def combined_with(self, other: Mark) -> Mark:
"""Return a new Mark which is a combination of this
Mark and another Mark.
@ -244,8 +247,8 @@ class Mark:
assert self.name == other.name
# Remember source of ids with parametrize Marks.
param_ids_from: Optional[Mark] = None
if self.name == "parametrize":
param_ids_from: Mark | None = None
if self.name == 'parametrize':
if other._has_param_ids():
param_ids_from = other
elif self._has_param_ids():
@ -263,7 +266,7 @@ class Mark:
# A generic parameter designating an object to which a Mark may
# be applied -- a test function (callable) or class.
# Note: a lambda is not allowed, but this can't be represented.
Markable = TypeVar("Markable", bound=Union[Callable[..., object], type])
Markable = TypeVar('Markable', bound=Union[Callable[..., object], type])
@dataclasses.dataclass
@ -315,7 +318,7 @@ class MarkDecorator:
return self.mark.name
@property
def args(self) -> Tuple[Any, ...]:
def args(self) -> tuple[Any, ...]:
"""Alias for mark.args."""
return self.mark.args
@ -329,7 +332,7 @@ class MarkDecorator:
""":meta private:"""
return self.name # for backward-compat (2.4.1 had this attr)
def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator":
def with_args(self, *args: object, **kwargs: object) -> MarkDecorator:
"""Return a MarkDecorator with extra arguments added.
Unlike calling the MarkDecorator, with_args() can be used even
@ -346,7 +349,7 @@ class MarkDecorator:
pass
@overload
def __call__(self, *args: object, **kwargs: object) -> "MarkDecorator":
def __call__(self, *args: object, **kwargs: object) -> MarkDecorator:
pass
def __call__(self, *args: object, **kwargs: object):
@ -361,10 +364,10 @@ class MarkDecorator:
def get_unpacked_marks(
obj: Union[object, type],
obj: object | type,
*,
consider_mro: bool = True,
) -> List[Mark]:
) -> list[Mark]:
"""Obtain the unpacked marks that are stored on an object.
If obj is a class and consider_mro is true, return marks applied to
@ -373,10 +376,10 @@ def get_unpacked_marks(
"""
if isinstance(obj, type):
if not consider_mro:
mark_lists = [obj.__dict__.get("pytestmark", [])]
mark_lists = [obj.__dict__.get('pytestmark', [])]
else:
mark_lists = [
x.__dict__.get("pytestmark", []) for x in reversed(obj.__mro__)
x.__dict__.get('pytestmark', []) for x in reversed(obj.__mro__)
]
mark_list = []
for item in mark_lists:
@ -385,7 +388,7 @@ def get_unpacked_marks(
else:
mark_list.append(item)
else:
mark_attribute = getattr(obj, "pytestmark", [])
mark_attribute = getattr(obj, 'pytestmark', [])
if isinstance(mark_attribute, list):
mark_list = mark_attribute
else:
@ -394,7 +397,7 @@ def get_unpacked_marks(
def normalize_mark_list(
mark_list: Iterable[Union[Mark, MarkDecorator]],
mark_list: Iterable[Mark | MarkDecorator],
) -> Iterable[Mark]:
"""
Normalize an iterable of Mark or MarkDecorator objects into a list of marks
@ -404,9 +407,9 @@ def normalize_mark_list(
:returns: A new list of the extracted Mark objects
"""
for mark in mark_list:
mark_obj = getattr(mark, "mark", mark)
mark_obj = getattr(mark, 'mark', mark)
if not isinstance(mark_obj, Mark):
raise TypeError(f"got {mark_obj!r} instead of Mark")
raise TypeError(f'got {mark_obj!r} instead of Mark')
yield mark_obj
@ -438,14 +441,14 @@ if TYPE_CHECKING:
...
@overload
def __call__(self, reason: str = ...) -> "MarkDecorator":
def __call__(self, reason: str = ...) -> MarkDecorator:
...
class _SkipifMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override]
self,
condition: Union[str, bool] = ...,
*conditions: Union[str, bool],
condition: str | bool = ...,
*conditions: str | bool,
reason: str = ...,
) -> MarkDecorator:
...
@ -458,13 +461,13 @@ if TYPE_CHECKING:
@overload
def __call__(
self,
condition: Union[str, bool] = False,
*conditions: Union[str, bool],
condition: str | bool = False,
*conditions: str | bool,
reason: str = ...,
run: bool = ...,
raises: Union[
None, Type[BaseException], Tuple[Type[BaseException], ...]
] = ...,
raises: (
None | type[BaseException] | tuple[type[BaseException], ...]
) = ...,
strict: bool = ...,
) -> MarkDecorator:
...
@ -472,17 +475,15 @@ if TYPE_CHECKING:
class _ParametrizeMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override]
self,
argnames: Union[str, Sequence[str]],
argvalues: Iterable[Union[ParameterSet, Sequence[object], object]],
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
*,
indirect: Union[bool, Sequence[str]] = ...,
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
Callable[[Any], Optional[object]],
]
] = ...,
scope: Optional[_ScopeName] = ...,
indirect: bool | Sequence[str] = ...,
ids: None | (
Iterable[None | str | float | int | bool] |
Callable[[Any], object | None]
) = ...,
scope: _ScopeName | None = ...,
) -> MarkDecorator:
...
@ -523,24 +524,24 @@ class MarkGenerator:
def __init__(self, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._config: Optional[Config] = None
self._markers: Set[str] = set()
self._config: Config | None = None
self._markers: set[str] = set()
def __getattr__(self, name: str) -> MarkDecorator:
"""Generate a new :class:`MarkDecorator` with the given name."""
if name[0] == "_":
raise AttributeError("Marker name must NOT start with underscore")
if name[0] == '_':
raise AttributeError('Marker name must NOT start with underscore')
if self._config is not None:
# We store a set of markers as a performance optimisation - if a mark
# name is in the set we definitely know it, but a mark may be known and
# not in the set. We therefore start by updating the set!
if name not in self._markers:
for line in self._config.getini("markers"):
for line in self._config.getini('markers'):
# example lines: "skipif(condition): skip the given test if..."
# or "hypothesis: tests which use Hypothesis", so to get the
# marker name we split on both `:` and `(`.
marker = line.split(":")[0].split("(")[0].strip()
marker = line.split(':')[0].split('(')[0].strip()
self._markers.add(marker)
# If the name is not in the set of known marks after updating,
@ -548,19 +549,19 @@ class MarkGenerator:
if name not in self._markers:
if self._config.option.strict_markers or self._config.option.strict:
fail(
f"{name!r} not found in `markers` configuration option",
f'{name!r} not found in `markers` configuration option',
pytrace=False,
)
# Raise a specific error for common misspellings of "parametrize".
if name in ["parameterize", "parametrise", "parameterise"]:
if name in ['parameterize', 'parametrise', 'parameterise']:
__tracebackhide__ = True
fail(f"Unknown '{name}' mark, did you mean 'parametrize'?")
warnings.warn(
"Unknown pytest.mark.%s - is this a typo? You can register "
"custom marks to avoid this warning - for details, see "
"https://docs.pytest.org/en/stable/how-to/mark.html" % name,
'Unknown pytest.mark.%s - is this a typo? You can register '
'custom marks to avoid this warning - for details, see '
'https://docs.pytest.org/en/stable/how-to/mark.html' % name,
PytestUnknownMarkWarning,
2,
)
@ -573,9 +574,9 @@ MARK_GEN = MarkGenerator(_ispytest=True)
@final
class NodeKeywords(MutableMapping[str, Any]):
__slots__ = ("node", "parent", "_markers")
__slots__ = ('node', 'parent', '_markers')
def __init__(self, node: "Node") -> None:
def __init__(self, node: Node) -> None:
self.node = node
self.parent = node.parent
self._markers = {node.name: True}
@ -596,21 +597,21 @@ class NodeKeywords(MutableMapping[str, Any]):
def __contains__(self, key: object) -> bool:
return (
key in self._markers
or self.parent is not None
and key in self.parent.keywords
key in self._markers or
self.parent is not None and
key in self.parent.keywords
)
def update( # type: ignore[override]
self,
other: Union[Mapping[str, Any], Iterable[Tuple[str, Any]]] = (),
other: Mapping[str, Any] | Iterable[tuple[str, Any]] = (),
**kwds: Any,
) -> None:
self._markers.update(other)
self._markers.update(kwds)
def __delitem__(self, key: str) -> None:
raise ValueError("cannot delete key in keywords dict")
raise ValueError('cannot delete key in keywords dict')
def __iter__(self) -> Iterator[str]:
# Doesn't need to be fast.
@ -626,4 +627,4 @@ class NodeKeywords(MutableMapping[str, Any]):
return sum(1 for keyword in self)
def __repr__(self) -> str:
return f"<NodeKeywords for node {self.node}>"
return f'<NodeKeywords for node {self.node}>'

View file

@ -1,9 +1,12 @@
# mypy: allow-untyped-defs
"""Monkeypatching and mocking functionality."""
from contextlib import contextmanager
from __future__ import annotations
import os
import re
import sys
import warnings
from contextlib import contextmanager
from typing import Any
from typing import final
from typing import Generator
@ -15,21 +18,20 @@ from typing import overload
from typing import Tuple
from typing import TypeVar
from typing import Union
import warnings
from _pytest.fixtures import fixture
from _pytest.warning_types import PytestWarning
RE_IMPORT_ERROR_NAME = re.compile(r"^No module named (.*)$")
RE_IMPORT_ERROR_NAME = re.compile(r'^No module named (.*)$')
K = TypeVar("K")
V = TypeVar("V")
K = TypeVar('K')
V = TypeVar('V')
@fixture
def monkeypatch() -> Generator["MonkeyPatch", None, None]:
def monkeypatch() -> Generator[MonkeyPatch, None, None]:
"""A convenient fixture for monkey-patching.
The fixture provides these methods to modify objects, dictionaries, or
@ -60,12 +62,12 @@ def monkeypatch() -> Generator["MonkeyPatch", None, None]:
def resolve(name: str) -> object:
# Simplified from zope.dottedname.
parts = name.split(".")
parts = name.split('.')
used = parts.pop(0)
found: object = __import__(used)
for part in parts:
used += "." + part
used += '.' + part
try:
found = getattr(found, part)
except AttributeError:
@ -81,7 +83,7 @@ def resolve(name: str) -> object:
if expected == used:
raise
else:
raise ImportError(f"import error in {used}: {ex}") from ex
raise ImportError(f'import error in {used}: {ex}') from ex
found = annotated_getattr(found, part, used)
return found
@ -91,15 +93,15 @@ def annotated_getattr(obj: object, name: str, ann: str) -> object:
obj = getattr(obj, name)
except AttributeError as e:
raise AttributeError(
f"{type(obj).__name__!r} object at {ann} has no attribute {name!r}"
f'{type(obj).__name__!r} object at {ann} has no attribute {name!r}',
) from e
return obj
def derive_importpath(import_path: str, raising: bool) -> Tuple[str, object]:
if not isinstance(import_path, str) or "." not in import_path:
raise TypeError(f"must be absolute import path string, not {import_path!r}")
module, attr = import_path.rsplit(".", 1)
def derive_importpath(import_path: str, raising: bool) -> tuple[str, object]:
if not isinstance(import_path, str) or '.' not in import_path:
raise TypeError(f'must be absolute import path string, not {import_path!r}')
module, attr = import_path.rsplit('.', 1)
target = resolve(module)
if raising:
annotated_getattr(target, attr, ann=module)
@ -108,7 +110,7 @@ def derive_importpath(import_path: str, raising: bool) -> Tuple[str, object]:
class Notset:
def __repr__(self) -> str:
return "<notset>"
return '<notset>'
notset = Notset()
@ -129,14 +131,14 @@ class MonkeyPatch:
"""
def __init__(self) -> None:
self._setattr: List[Tuple[object, str, object]] = []
self._setitem: List[Tuple[Mapping[Any, Any], object, object]] = []
self._cwd: Optional[str] = None
self._savesyspath: Optional[List[str]] = None
self._setattr: list[tuple[object, str, object]] = []
self._setitem: list[tuple[Mapping[Any, Any], object, object]] = []
self._cwd: str | None = None
self._savesyspath: list[str] | None = None
@classmethod
@contextmanager
def context(cls) -> Generator["MonkeyPatch", None, None]:
def context(cls) -> Generator[MonkeyPatch, None, None]:
"""Context manager that returns a new :class:`MonkeyPatch` object
which undoes any patching done inside the ``with`` block upon exit.
@ -182,8 +184,8 @@ class MonkeyPatch:
def setattr(
self,
target: Union[str, object],
name: Union[object, str],
target: str | object,
name: object | str,
value: object = notset,
raising: bool = True,
) -> None:
@ -228,23 +230,23 @@ class MonkeyPatch:
if isinstance(value, Notset):
if not isinstance(target, str):
raise TypeError(
"use setattr(target, name, value) or "
"setattr(target, value) with target being a dotted "
"import string"
'use setattr(target, name, value) or '
'setattr(target, value) with target being a dotted '
'import string',
)
value = name
name, target = derive_importpath(target, raising)
else:
if not isinstance(name, str):
raise TypeError(
"use setattr(target, name, value) with name being a string or "
"setattr(target, value) with target being a dotted "
"import string"
'use setattr(target, name, value) with name being a string or '
'setattr(target, value) with target being a dotted '
'import string',
)
oldval = getattr(target, name, notset)
if raising and oldval is notset:
raise AttributeError(f"{target!r} has no attribute {name!r}")
raise AttributeError(f'{target!r} has no attribute {name!r}')
# avoid class descriptors like staticmethod/classmethod
if inspect.isclass(target):
@ -254,8 +256,8 @@ class MonkeyPatch:
def delattr(
self,
target: Union[object, str],
name: Union[str, Notset] = notset,
target: object | str,
name: str | Notset = notset,
raising: bool = True,
) -> None:
"""Delete attribute ``name`` from ``target``.
@ -273,9 +275,9 @@ class MonkeyPatch:
if isinstance(name, Notset):
if not isinstance(target, str):
raise TypeError(
"use delattr(target, name) or "
"delattr(target) with target being a dotted "
"import string"
'use delattr(target, name) or '
'delattr(target) with target being a dotted '
'import string',
)
name, target = derive_importpath(target, raising)
@ -310,7 +312,7 @@ class MonkeyPatch:
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
del dic[name] # type: ignore[attr-defined]
def setenv(self, name: str, value: str, prepend: Optional[str] = None) -> None:
def setenv(self, name: str, value: str, prepend: str | None = None) -> None:
"""Set environment variable ``name`` to ``value``.
If ``prepend`` is a character, read the current environment variable
@ -320,8 +322,8 @@ class MonkeyPatch:
if not isinstance(value, str):
warnings.warn( # type: ignore[unreachable]
PytestWarning(
f"Value of environment variable {name} type should be str, but got "
f"{value!r} (type: {type(value).__name__}); converted to str implicitly"
f'Value of environment variable {name} type should be str, but got '
f'{value!r} (type: {type(value).__name__}); converted to str implicitly',
),
stacklevel=2,
)
@ -347,7 +349,7 @@ class MonkeyPatch:
# https://github.com/pypa/setuptools/blob/d8b901bc/docs/pkg_resources.txt#L162-L171
# this is only needed when pkg_resources was already loaded by the namespace package
if "pkg_resources" in sys.modules:
if 'pkg_resources' in sys.modules:
from pkg_resources import fixup_namespace_packages
fixup_namespace_packages(str(path))
@ -363,7 +365,7 @@ class MonkeyPatch:
invalidate_caches()
def chdir(self, path: Union[str, "os.PathLike[str]"]) -> None:
def chdir(self, path: str | os.PathLike[str]) -> None:
"""Change the current working directory to the specified path.
:param path:

View file

@ -1,9 +1,12 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import abc
from functools import cached_property
from inspect import signature
import os
import pathlib
import warnings
from functools import cached_property
from inspect import signature
from pathlib import Path
from typing import Any
from typing import Callable
@ -21,11 +24,9 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import warnings
import pluggy
import _pytest._code
import pluggy
from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
@ -53,18 +54,18 @@ if TYPE_CHECKING:
from _pytest.main import Session
SEP = "/"
SEP = '/'
tracebackcutdir = Path(_pytest.__file__).parent
_T = TypeVar("_T")
_T = TypeVar('_T')
def _imply_path(
node_type: Type["Node"],
path: Optional[Path],
fspath: Optional[LEGACY_PATH],
node_type: type[Node],
path: Path | None,
fspath: LEGACY_PATH | None,
) -> Path:
if fspath is not None:
warnings.warn(
@ -82,7 +83,7 @@ def _imply_path(
return Path(fspath)
_NodeType = TypeVar("_NodeType", bound="Node")
_NodeType = TypeVar('_NodeType', bound='Node')
class NodeMeta(abc.ABCMeta):
@ -102,28 +103,28 @@ class NodeMeta(abc.ABCMeta):
def __call__(cls, *k, **kw) -> NoReturn:
msg = (
"Direct construction of {name} has been deprecated, please use {name}.from_parent.\n"
"See "
"https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent"
" for more details."
).format(name=f"{cls.__module__}.{cls.__name__}")
'Direct construction of {name} has been deprecated, please use {name}.from_parent.\n'
'See '
'https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent'
' for more details.'
).format(name=f'{cls.__module__}.{cls.__name__}')
fail(msg, pytrace=False)
def _create(cls: Type[_T], *k, **kw) -> _T:
def _create(cls: type[_T], *k, **kw) -> _T:
try:
return super().__call__(*k, **kw) # type: ignore[no-any-return,misc]
except TypeError:
sig = signature(getattr(cls, "__init__"))
sig = signature(getattr(cls, '__init__'))
known_kw = {k: v for k, v in kw.items() if k in sig.parameters}
from .warning_types import PytestDeprecationWarning
warnings.warn(
PytestDeprecationWarning(
f"{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n"
"See https://docs.pytest.org/en/stable/deprecations.html"
"#constructors-of-custom-pytest-node-subclasses-should-take-kwargs "
"for more details."
)
f'{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n'
'See https://docs.pytest.org/en/stable/deprecations.html'
'#constructors-of-custom-pytest-node-subclasses-should-take-kwargs '
'for more details.',
),
)
return super().__call__(*k, **known_kw) # type: ignore[no-any-return,misc]
@ -147,25 +148,25 @@ class Node(abc.ABC, metaclass=NodeMeta):
# Use __slots__ to make attribute access faster.
# Note that __dict__ is still available.
__slots__ = (
"name",
"parent",
"config",
"session",
"path",
"_nodeid",
"_store",
"__dict__",
'name',
'parent',
'config',
'session',
'path',
'_nodeid',
'_store',
'__dict__',
)
def __init__(
self,
name: str,
parent: "Optional[Node]" = None,
config: Optional[Config] = None,
session: "Optional[Session]" = None,
fspath: Optional[LEGACY_PATH] = None,
path: Optional[Path] = None,
nodeid: Optional[str] = None,
parent: Optional[Node] = None,
config: Config | None = None,
session: Optional[Session] = None,
fspath: LEGACY_PATH | None = None,
path: Path | None = None,
nodeid: str | None = None,
) -> None:
#: A unique name within the scope of the parent node.
self.name: str = name
@ -178,7 +179,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.config: Config = config
else:
if not parent:
raise TypeError("config or parent must be provided")
raise TypeError('config or parent must be provided')
self.config = parent.config
if session:
@ -186,11 +187,11 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.session: Session = session
else:
if not parent:
raise TypeError("session or parent must be provided")
raise TypeError('session or parent must be provided')
self.session = parent.session
if path is None and fspath is None:
path = getattr(parent, "path", None)
path = getattr(parent, 'path', None)
#: Filesystem path where this node was collected from (can be None).
self.path: pathlib.Path = _imply_path(type(self), path, fspath=fspath)
@ -199,18 +200,18 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.keywords: MutableMapping[str, Any] = NodeKeywords(self)
#: The marker objects belonging to this node.
self.own_markers: List[Mark] = []
self.own_markers: list[Mark] = []
#: Allow adding of extra keywords to use for matching.
self.extra_keyword_matches: Set[str] = set()
self.extra_keyword_matches: set[str] = set()
if nodeid is not None:
assert "::()" not in nodeid
assert '::()' not in nodeid
self._nodeid = nodeid
else:
if not self.parent:
raise TypeError("nodeid or parent must be provided")
self._nodeid = self.parent.nodeid + "::" + self.name
raise TypeError('nodeid or parent must be provided')
self._nodeid = self.parent.nodeid + '::' + self.name
#: A place where plugins can store information on the node for their
#: own use.
@ -219,7 +220,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
self._store = self.stash
@classmethod
def from_parent(cls, parent: "Node", **kw) -> "Self":
def from_parent(cls, parent: Node, **kw) -> Self:
"""Public constructor for Nodes.
This indirection got introduced in order to enable removing
@ -230,10 +231,10 @@ class Node(abc.ABC, metaclass=NodeMeta):
:param parent: The parent node of this Node.
"""
if "config" in kw:
raise TypeError("config is not a valid argument for from_parent")
if "session" in kw:
raise TypeError("session is not a valid argument for from_parent")
if 'config' in kw:
raise TypeError('config is not a valid argument for from_parent')
if 'session' in kw:
raise TypeError('session is not a valid argument for from_parent')
return cls._create(parent=parent, **kw)
@property
@ -242,7 +243,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return self.session.gethookproxy(self.path)
def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
return '<{} {}>'.format(self.__class__.__name__, getattr(self, 'name', None))
def warn(self, warning: Warning) -> None:
"""Issue a warning for this Node.
@ -268,7 +269,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
# enforce type checks here to avoid getting a generic type error later otherwise.
if not isinstance(warning, Warning):
raise ValueError(
f"warning must be an instance of Warning or subclass, got {warning!r}"
f'warning must be an instance of Warning or subclass, got {warning!r}',
)
path, lineno = get_fslocation_from_item(self)
assert lineno is not None
@ -295,22 +296,22 @@ class Node(abc.ABC, metaclass=NodeMeta):
def teardown(self) -> None:
pass
def iter_parents(self) -> Iterator["Node"]:
def iter_parents(self) -> Iterator[Node]:
"""Iterate over all parent collectors starting from and including self
up to the root of the collection tree.
.. versionadded:: 8.1
"""
parent: Optional[Node] = self
parent: Node | None = self
while parent is not None:
yield parent
parent = parent.parent
def listchain(self) -> List["Node"]:
def listchain(self) -> list[Node]:
"""Return a list of all parent collectors starting from the root of the
collection tree down to and including self."""
chain = []
item: Optional[Node] = self
item: Node | None = self
while item is not None:
chain.append(item)
item = item.parent
@ -318,7 +319,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return chain
def add_marker(
self, marker: Union[str, MarkDecorator], append: bool = True
self, marker: str | MarkDecorator, append: bool = True,
) -> None:
"""Dynamically add a marker object to the node.
@ -334,14 +335,14 @@ class Node(abc.ABC, metaclass=NodeMeta):
elif isinstance(marker, str):
marker_ = getattr(MARK_GEN, marker)
else:
raise ValueError("is not a string or pytest.mark.* Marker")
raise ValueError('is not a string or pytest.mark.* Marker')
self.keywords[marker_.name] = marker_
if append:
self.own_markers.append(marker_.mark)
else:
self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]:
def iter_markers(self, name: str | None = None) -> Iterator[Mark]:
"""Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute.
@ -350,8 +351,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node(
self, name: Optional[str] = None
) -> Iterator[Tuple["Node", Mark]]:
self, name: str | None = None,
) -> Iterator[tuple[Node, Mark]]:
"""Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute.
@ -359,11 +360,11 @@ class Node(abc.ABC, metaclass=NodeMeta):
"""
for node in self.iter_parents():
for mark in node.own_markers:
if name is None or getattr(mark, "name", None) == name:
if name is None or getattr(mark, 'name', None) == name:
yield node, mark
@overload
def get_closest_marker(self, name: str) -> Optional[Mark]:
def get_closest_marker(self, name: str) -> Mark | None:
...
@overload
@ -371,8 +372,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
...
def get_closest_marker(
self, name: str, default: Optional[Mark] = None
) -> Optional[Mark]:
self, name: str, default: Mark | None = None,
) -> Mark | None:
"""Return the first marker matching the name, from closest (for
example function) to farther level (for example module level).
@ -381,14 +382,14 @@ class Node(abc.ABC, metaclass=NodeMeta):
"""
return next(self.iter_markers(name=name), default)
def listextrakeywords(self) -> Set[str]:
def listextrakeywords(self) -> set[str]:
"""Return a set of all extra keywords in self and any parents."""
extra_keywords: Set[str] = set()
extra_keywords: set[str] = set()
for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches)
return extra_keywords
def listnames(self) -> List[str]:
def listnames(self) -> list[str]:
return [x.name for x in self.listchain()]
def addfinalizer(self, fin: Callable[[], object]) -> None:
@ -400,7 +401,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
"""
self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]:
def getparent(self, cls: type[_NodeType]) -> _NodeType | None:
"""Get the closest parent node (including self) which is an instance of
the given class.
@ -418,7 +419,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
def _repr_failure_py(
self,
excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None,
style: Optional[_TracebackStyle] = None,
) -> TerminalRepr:
from _pytest.fixtures import FixtureLookupError
@ -426,26 +427,26 @@ class Node(abc.ABC, metaclass=NodeMeta):
excinfo = ExceptionInfo.from_exception(excinfo.value.cause)
if isinstance(excinfo.value, fail.Exception):
if not excinfo.value.pytrace:
style = "value"
style = 'value'
if isinstance(excinfo.value, FixtureLookupError):
return excinfo.value.formatrepr()
tbfilter: Union[bool, Callable[[ExceptionInfo[BaseException]], Traceback]]
if self.config.getoption("fulltrace", False):
style = "long"
tbfilter: bool | Callable[[ExceptionInfo[BaseException]], Traceback]
if self.config.getoption('fulltrace', False):
style = 'long'
tbfilter = False
else:
tbfilter = self._traceback_filter
if style == "auto":
style = "long"
if style == 'auto':
style = 'long'
# XXX should excinfo.getrepr record all data and toterminal() process it?
if style is None:
if self.config.getoption("tbstyle", "auto") == "short":
style = "short"
if self.config.getoption('tbstyle', 'auto') == 'short':
style = 'short'
else:
style = "long"
style = 'long'
if self.config.getoption("verbose", 0) > 1:
if self.config.getoption('verbose', 0) > 1:
truncate_locals = False
else:
truncate_locals = True
@ -464,7 +465,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return excinfo.getrepr(
funcargs=True,
abspath=abspath,
showlocals=self.config.getoption("showlocals", False),
showlocals=self.config.getoption('showlocals', False),
style=style,
tbfilter=tbfilter,
truncate_locals=truncate_locals,
@ -473,8 +474,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
def repr_failure(
self,
excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None,
) -> Union[str, TerminalRepr]:
style: Optional[_TracebackStyle] = None,
) -> str | TerminalRepr:
"""Return a representation of a collection or test failure.
.. seealso:: :ref:`non-python tests`
@ -484,7 +485,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return self._repr_failure_py(excinfo, style)
def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[int]]:
def get_fslocation_from_item(node: Node) -> tuple[str | Path, int | None]:
"""Try to extract the actual location from a node, depending on available attributes:
* "location": a pair (path, lineno)
@ -494,13 +495,13 @@ def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[i
:rtype: A tuple of (str|Path, int) with filename and 0-based line number.
"""
# See Item.location.
location: Optional[Tuple[str, Optional[int], str]] = getattr(node, "location", None)
location: tuple[str, int | None, str] | None = getattr(node, 'location', None)
if location is not None:
return location[:2]
obj = getattr(node, "obj", None)
obj = getattr(node, 'obj', None)
if obj is not None:
return getfslineno(obj)
return getattr(node, "path", "unknown location"), -1
return getattr(node, 'path', 'unknown location'), -1
class Collector(Node, abc.ABC):
@ -514,34 +515,34 @@ class Collector(Node, abc.ABC):
"""An error during collection, contains a custom message."""
@abc.abstractmethod
def collect(self) -> Iterable[Union["Item", "Collector"]]:
def collect(self) -> Iterable[Item | Collector]:
"""Collect children (items and collectors) for this collector."""
raise NotImplementedError("abstract")
raise NotImplementedError('abstract')
# TODO: This omits the style= parameter which breaks Liskov Substitution.
def repr_failure( # type: ignore[override]
self, excinfo: ExceptionInfo[BaseException]
) -> Union[str, TerminalRepr]:
self, excinfo: ExceptionInfo[BaseException],
) -> str | TerminalRepr:
"""Return a representation of a collection failure.
:param excinfo: Exception information for the failure.
"""
if isinstance(excinfo.value, self.CollectError) and not self.config.getoption(
"fulltrace", False
'fulltrace', False,
):
exc = excinfo.value
return str(exc.args[0])
# Respect explicit tbstyle option, but default to "short"
# (_repr_failure_py uses "long" with "fulltrace" option always).
tbstyle = self.config.getoption("tbstyle", "auto")
if tbstyle == "auto":
tbstyle = "short"
tbstyle = self.config.getoption('tbstyle', 'auto')
if tbstyle == 'auto':
tbstyle = 'short'
return self._repr_failure_py(excinfo, style=tbstyle)
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
if hasattr(self, "path"):
if hasattr(self, 'path'):
traceback = excinfo.traceback
ntraceback = traceback.cut(path=self.path)
if ntraceback == traceback:
@ -550,11 +551,11 @@ class Collector(Node, abc.ABC):
return excinfo.traceback
def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]:
def _check_initialpaths_for_relpath(session: Session, path: Path) -> str | None:
for initial_path in session._initialpaths:
if commonpath(path, initial_path) == initial_path:
rel = str(path.relative_to(initial_path))
return "" if rel == "." else rel
return '' if rel == '.' else rel
return None
@ -563,14 +564,14 @@ class FSCollector(Collector, abc.ABC):
def __init__(
self,
fspath: Optional[LEGACY_PATH] = None,
path_or_parent: Optional[Union[Path, Node]] = None,
path: Optional[Path] = None,
name: Optional[str] = None,
parent: Optional[Node] = None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
fspath: LEGACY_PATH | None = None,
path_or_parent: Path | Node | None = None,
path: Path | None = None,
name: str | None = None,
parent: Node | None = None,
config: Config | None = None,
session: Session | None = None,
nodeid: str | None = None,
) -> None:
if path_or_parent:
if isinstance(path_or_parent, Node):
@ -620,10 +621,10 @@ class FSCollector(Collector, abc.ABC):
cls,
parent,
*,
fspath: Optional[LEGACY_PATH] = None,
path: Optional[Path] = None,
fspath: LEGACY_PATH | None = None,
path: Path | None = None,
**kw,
) -> "Self":
) -> Self:
"""The public constructor."""
return super().from_parent(parent=parent, fspath=fspath, path=path, **kw)
@ -665,9 +666,9 @@ class Item(Node, abc.ABC):
self,
name,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
config: Config | None = None,
session: Session | None = None,
nodeid: str | None = None,
**kw,
) -> None:
# The first two arguments are intentionally passed positionally,
@ -682,11 +683,11 @@ class Item(Node, abc.ABC):
nodeid=nodeid,
**kw,
)
self._report_sections: List[Tuple[str, str, str]] = []
self._report_sections: list[tuple[str, str, str]] = []
#: A list of tuples (name, value) that holds user defined properties
#: for this test.
self.user_properties: List[Tuple[str, object]] = []
self.user_properties: list[tuple[str, object]] = []
self._check_item_and_collector_diamond_inheritance()
@ -701,21 +702,21 @@ class Item(Node, abc.ABC):
# for the same class more than once, which is not helpful.
# It is a hack, but was deemed acceptable in order to avoid
# flooding the user in the common case.
attr_name = "_pytest_diamond_inheritance_warning_shown"
attr_name = '_pytest_diamond_inheritance_warning_shown'
if getattr(cls, attr_name, False):
return
setattr(cls, attr_name, True)
problems = ", ".join(
problems = ', '.join(
base.__name__ for base in cls.__bases__ if issubclass(base, Collector)
)
if problems:
warnings.warn(
f"{cls.__name__} is an Item subclass and should not be a collector, "
f"however its bases {problems} are collectors.\n"
"Please split the Collectors and the Item into separate node types.\n"
"Pytest Doc example: https://docs.pytest.org/en/latest/example/nonpython.html\n"
"example pull request on a plugin: https://github.com/asmeurer/pytest-flakes/pull/40/",
f'{cls.__name__} is an Item subclass and should not be a collector, '
f'however its bases {problems} are collectors.\n'
'Please split the Collectors and the Item into separate node types.\n'
'Pytest Doc example: https://docs.pytest.org/en/latest/example/nonpython.html\n'
'example pull request on a plugin: https://github.com/asmeurer/pytest-flakes/pull/40/',
PytestWarning,
)
@ -727,7 +728,7 @@ class Item(Node, abc.ABC):
.. seealso:: :ref:`non-python tests`
"""
raise NotImplementedError("runtest must be implemented by Item subclass")
raise NotImplementedError('runtest must be implemented by Item subclass')
def add_report_section(self, when: str, key: str, content: str) -> None:
"""Add a new report section, similar to what's done internally to add
@ -746,7 +747,7 @@ class Item(Node, abc.ABC):
if content:
self._report_sections.append((when, key, content))
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]:
def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
"""Get location information for this item for test reports.
Returns a tuple with three elements:
@ -757,10 +758,10 @@ class Item(Node, abc.ABC):
.. seealso:: :ref:`non-python tests`
"""
return self.path, None, ""
return self.path, None, ''
@cached_property
def location(self) -> Tuple[str, Optional[int], str]:
def location(self) -> tuple[str, int | None, str]:
"""
Returns a tuple of ``(relfspath, lineno, testname)`` for this item
where ``relfspath`` is file path relative to ``config.rootpath``

View file

@ -1,5 +1,6 @@
"""Exception classes and constants handling test outcomes as well as
functions creating them."""
from __future__ import annotations
import sys
from typing import Any
@ -16,11 +17,11 @@ class OutcomeException(BaseException):
"""OutcomeException and its subclass instances indicate and contain info
about test and collection outcomes."""
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None:
def __init__(self, msg: str | None = None, pytrace: bool = True) -> None:
if msg is not None and not isinstance(msg, str):
error_msg = ( # type: ignore[unreachable]
"{} expected string as 'msg' parameter, got '{}' instead.\n"
"Perhaps you meant to use a mark?"
'Perhaps you meant to use a mark?'
)
raise TypeError(error_msg.format(type(self).__name__, type(msg).__name__))
super().__init__(msg)
@ -30,7 +31,7 @@ class OutcomeException(BaseException):
def __repr__(self) -> str:
if self.msg is not None:
return self.msg
return f"<{self.__class__.__name__} instance>"
return f'<{self.__class__.__name__} instance>'
__str__ = __repr__
@ -41,11 +42,11 @@ TEST_OUTCOME = (OutcomeException, Exception)
class Skipped(OutcomeException):
# XXX hackish: on 3k we fake to live in the builtins
# in order to have Skipped exception printing shorter/nicer
__module__ = "builtins"
__module__ = 'builtins'
def __init__(
self,
msg: Optional[str] = None,
msg: str | None = None,
pytrace: bool = True,
allow_module_level: bool = False,
*,
@ -61,14 +62,14 @@ class Skipped(OutcomeException):
class Failed(OutcomeException):
"""Raised from an explicit call to pytest.fail()."""
__module__ = "builtins"
__module__ = 'builtins'
class Exit(Exception):
"""Raised for immediate program exits (no tracebacks/summaries)."""
def __init__(
self, msg: str = "unknown reason", returncode: Optional[int] = None
self, msg: str = 'unknown reason', returncode: int | None = None,
) -> None:
self.msg = msg
self.returncode = returncode
@ -78,8 +79,8 @@ class Exit(Exception):
# Elaborate hack to work around https://github.com/python/mypy/issues/2087.
# Ideally would just be `exit.Exception = Exit` etc.
_F = TypeVar("_F", bound=Callable[..., object])
_ET = TypeVar("_ET", bound=Type[BaseException])
_F = TypeVar('_F', bound=Callable[..., object])
_ET = TypeVar('_ET', bound=Type[BaseException])
class _WithException(Protocol[_F, _ET]):
@ -101,8 +102,8 @@ def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _E
@_with_exception(Exit)
def exit(
reason: str = "",
returncode: Optional[int] = None,
reason: str = '',
returncode: int | None = None,
) -> NoReturn:
"""Exit testing process.
@ -119,7 +120,7 @@ def exit(
@_with_exception(Skipped)
def skip(
reason: str = "",
reason: str = '',
*,
allow_module_level: bool = False,
) -> NoReturn:
@ -152,7 +153,7 @@ def skip(
@_with_exception(Failed)
def fail(reason: str = "", pytrace: bool = True) -> NoReturn:
def fail(reason: str = '', pytrace: bool = True) -> NoReturn:
"""Explicitly fail an executing test with the given message.
:param reason:
@ -171,7 +172,7 @@ class XFailed(Failed):
@_with_exception(XFailed)
def xfail(reason: str = "") -> NoReturn:
def xfail(reason: str = '') -> NoReturn:
"""Imperatively xfail an executing test or setup function with the given reason.
This function should be called only during testing (setup, call or teardown).
@ -192,7 +193,7 @@ def xfail(reason: str = "") -> NoReturn:
def importorskip(
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
modname: str, minversion: str | None = None, reason: str | None = None,
) -> Any:
"""Import and return the requested module ``modname``, or skip the
current test if the module cannot be imported.
@ -216,30 +217,30 @@ def importorskip(
import warnings
__tracebackhide__ = True
compile(modname, "", "eval") # to catch syntaxerrors
compile(modname, '', 'eval') # to catch syntaxerrors
with warnings.catch_warnings():
# Make sure to ignore ImportWarnings that might happen because
# of existing directories with the same name we're trying to
# import but without a __init__.py file.
warnings.simplefilter("ignore")
warnings.simplefilter('ignore')
try:
__import__(modname)
except ImportError as exc:
if reason is None:
reason = f"could not import {modname!r}: {exc}"
reason = f'could not import {modname!r}: {exc}'
raise Skipped(reason, allow_module_level=True) from None
mod = sys.modules[modname]
if minversion is None:
return mod
verattr = getattr(mod, "__version__", None)
verattr = getattr(mod, '__version__', None)
if minversion is not None:
# Imported lazily to improve start-up time.
from packaging.version import Version
if verattr is None or Version(verattr) < Version(minversion):
raise Skipped(
f"module {modname!r} has __version__ {verattr!r}, required is: {minversion!r}",
f'module {modname!r} has __version__ {verattr!r}, required is: {minversion!r}',
allow_module_level=True,
)
return mod

View file

@ -1,50 +1,52 @@
# mypy: allow-untyped-defs
"""Submit failure or test session information to a pastebin service."""
from io import StringIO
from __future__ import annotations
import tempfile
from io import StringIO
from typing import IO
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
import pytest
pastebinfile_key = StashKey[IO[bytes]]()
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group = parser.getgroup('terminal reporting')
group._addoption(
"--pastebin",
metavar="mode",
action="store",
dest="pastebin",
'--pastebin',
metavar='mode',
action='store',
dest='pastebin',
default=None,
choices=["failed", "all"],
help="Send failed|all info to bpaste.net pastebin service",
choices=['failed', 'all'],
help='Send failed|all info to bpaste.net pastebin service',
)
@pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
if config.option.pastebin == "all":
tr = config.pluginmanager.getplugin("terminalreporter")
if config.option.pastebin == 'all':
tr = config.pluginmanager.getplugin('terminalreporter')
# If no terminal reporter plugin is present, nothing we can do here;
# this can happen when this function executes in a worker node
# when using pytest-xdist, for example.
if tr is not None:
# pastebin file will be UTF-8 encoded binary file.
config.stash[pastebinfile_key] = tempfile.TemporaryFile("w+b")
config.stash[pastebinfile_key] = tempfile.TemporaryFile('w+b')
oldwrite = tr._tw.write
def tee_write(s, **kwargs):
oldwrite(s, **kwargs)
if isinstance(s, str):
s = s.encode("utf-8")
s = s.encode('utf-8')
config.stash[pastebinfile_key].write(s)
tr._tw.write = tee_write
@ -59,15 +61,15 @@ def pytest_unconfigure(config: Config) -> None:
pastebinfile.close()
del config.stash[pastebinfile_key]
# Undo our patching in the terminal reporter.
tr = config.pluginmanager.getplugin("terminalreporter")
del tr._tw.__dict__["write"]
tr = config.pluginmanager.getplugin('terminalreporter')
del tr._tw.__dict__['write']
# Write summary.
tr.write_sep("=", "Sending information to Paste Service")
tr.write_sep('=', 'Sending information to Paste Service')
pastebinurl = create_new_paste(sessionlog)
tr.write_line("pastebin session-log: %s\n" % pastebinurl)
tr.write_line('pastebin session-log: %s\n' % pastebinurl)
def create_new_paste(contents: Union[str, bytes]) -> str:
def create_new_paste(contents: str | bytes) -> str:
"""Create a new paste using the bpaste.net service.
:contents: Paste contents string.
@ -77,27 +79,27 @@ def create_new_paste(contents: Union[str, bytes]) -> str:
from urllib.parse import urlencode
from urllib.request import urlopen
params = {"code": contents, "lexer": "text", "expiry": "1week"}
url = "https://bpa.st"
params = {'code': contents, 'lexer': 'text', 'expiry': '1week'}
url = 'https://bpa.st'
try:
response: str = (
urlopen(url, data=urlencode(params).encode("ascii")).read().decode("utf-8")
urlopen(url, data=urlencode(params).encode('ascii')).read().decode('utf-8')
)
except OSError as exc_info: # urllib errors
return "bad response: %s" % exc_info
return 'bad response: %s' % exc_info
m = re.search(r'href="/raw/(\w+)"', response)
if m:
return f"{url}/show/{m.group(1)}"
return f'{url}/show/{m.group(1)}'
else:
return "bad response: invalid format ('" + response + "')"
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
if terminalreporter.config.option.pastebin != "failed":
if terminalreporter.config.option.pastebin != 'failed':
return
if "failed" in terminalreporter.stats:
terminalreporter.write_sep("=", "Sending information to Paste Service")
for rep in terminalreporter.stats["failed"]:
if 'failed' in terminalreporter.stats:
terminalreporter.write_sep('=', 'Sending information to Paste Service')
for rep in terminalreporter.stats['failed']:
try:
msg = rep.longrepr.reprtraceback.reprentries[-1].reprfileloc
except AttributeError:
@ -108,4 +110,4 @@ def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
s = file.getvalue()
assert len(s)
pastebinurl = create_new_paste(s)
terminalreporter.write_line(f"{msg} --> {pastebinurl}")
terminalreporter.write_line(f'{msg} --> {pastebinurl}')

View file

@ -1,16 +1,23 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import atexit
import contextlib
import fnmatch
import importlib.util
import itertools
import os
import shutil
import sys
import types
import uuid
import warnings
from enum import Enum
from errno import EBADF
from errno import ELOOP
from errno import ENOENT
from errno import ENOTDIR
import fnmatch
from functools import partial
import importlib.util
import itertools
import os
from os.path import expanduser
from os.path import expandvars
from os.path import isabs
@ -18,9 +25,6 @@ from os.path import sep
from pathlib import Path
from pathlib import PurePath
from posixpath import sep as posix_sep
import shutil
import sys
import types
from types import ModuleType
from typing import Callable
from typing import Dict
@ -33,8 +37,6 @@ from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
import uuid
import warnings
from _pytest.compat import assert_never
from _pytest.outcomes import skip
@ -44,7 +46,7 @@ from _pytest.warning_types import PytestWarning
LOCK_TIMEOUT = 60 * 60 * 24 * 3
_AnyPurePath = TypeVar("_AnyPurePath", bound=PurePath)
_AnyPurePath = TypeVar('_AnyPurePath', bound=PurePath)
# The following function, variables and comments were
# copied from cpython 3.9 Lib/pathlib.py file.
@ -60,22 +62,22 @@ _IGNORED_WINERRORS = (
def _ignore_error(exception):
return (
getattr(exception, "errno", None) in _IGNORED_ERRORS
or getattr(exception, "winerror", None) in _IGNORED_WINERRORS
getattr(exception, 'errno', None) in _IGNORED_ERRORS or
getattr(exception, 'winerror', None) in _IGNORED_WINERRORS
)
def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
return path.joinpath(".lock")
return path.joinpath('.lock')
def on_rm_rf_error(
func,
path: str,
excinfo: Union[
BaseException,
Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]],
],
excinfo: (
BaseException |
tuple[type[BaseException], BaseException, types.TracebackType | None]
),
*,
start_path: Path,
) -> bool:
@ -95,7 +97,7 @@ def on_rm_rf_error(
if not isinstance(exc, PermissionError):
warnings.warn(
PytestWarning(f"(rm_rf) error removing {path}\n{type(exc)}: {exc}")
PytestWarning(f'(rm_rf) error removing {path}\n{type(exc)}: {exc}'),
)
return False
@ -103,8 +105,8 @@ def on_rm_rf_error(
if func not in (os.open,):
warnings.warn(
PytestWarning(
f"(rm_rf) unknown function {func} when removing {path}:\n{type(exc)}: {exc}"
)
f'(rm_rf) unknown function {func} when removing {path}:\n{type(exc)}: {exc}',
),
)
return False
@ -142,7 +144,7 @@ def ensure_extended_length_path(path: Path) -> Path:
On Windows, this function returns the extended-length absolute version of path.
On other platforms it returns path unchanged.
"""
if sys.platform.startswith("win32"):
if sys.platform.startswith('win32'):
path = path.resolve()
path = Path(get_extended_length_path_str(str(path)))
return path
@ -150,12 +152,12 @@ def ensure_extended_length_path(path: Path) -> Path:
def get_extended_length_path_str(path: str) -> str:
"""Convert a path to a Windows extended length path."""
long_path_prefix = "\\\\?\\"
unc_long_path_prefix = "\\\\?\\UNC\\"
long_path_prefix = '\\\\?\\'
unc_long_path_prefix = '\\\\?\\UNC\\'
if path.startswith((long_path_prefix, unc_long_path_prefix)):
return path
# UNC
if path.startswith("\\\\"):
if path.startswith('\\\\'):
return unc_long_path_prefix + path[2:]
return long_path_prefix + path
@ -171,7 +173,7 @@ def rm_rf(path: Path) -> None:
shutil.rmtree(str(path), onerror=onerror)
def find_prefixed(root: Path, prefix: str) -> Iterator["os.DirEntry[str]"]:
def find_prefixed(root: Path, prefix: str) -> Iterator[os.DirEntry[str]]:
"""Find all elements in root that begin with the prefix, case insensitive."""
l_prefix = prefix.lower()
for x in os.scandir(root):
@ -179,7 +181,7 @@ def find_prefixed(root: Path, prefix: str) -> Iterator["os.DirEntry[str]"]:
yield x
def extract_suffixes(iter: Iterable["os.DirEntry[str]"], prefix: str) -> Iterator[str]:
def extract_suffixes(iter: Iterable[os.DirEntry[str]], prefix: str) -> Iterator[str]:
"""Return the parts of the paths following the prefix.
:param iter: Iterator over path names.
@ -204,7 +206,7 @@ def parse_num(maybe_num) -> int:
def _force_symlink(
root: Path, target: Union[str, PurePath], link_to: Union[str, Path]
root: Path, target: str | PurePath, link_to: str | Path,
) -> None:
"""Helper to create the current symlink.
@ -231,18 +233,18 @@ def make_numbered_dir(root: Path, prefix: str, mode: int = 0o700) -> Path:
# try up to 10 times to create the folder
max_existing = max(map(parse_num, find_suffixes(root, prefix)), default=-1)
new_number = max_existing + 1
new_path = root.joinpath(f"{prefix}{new_number}")
new_path = root.joinpath(f'{prefix}{new_number}')
try:
new_path.mkdir(mode=mode)
except Exception:
pass
else:
_force_symlink(root, prefix + "current", new_path)
_force_symlink(root, prefix + 'current', new_path)
return new_path
else:
raise OSError(
"could not create numbered dir with prefix "
f"{prefix} in {root} after 10 tries"
'could not create numbered dir with prefix '
f'{prefix} in {root} after 10 tries',
)
@ -252,14 +254,14 @@ def create_cleanup_lock(p: Path) -> Path:
try:
fd = os.open(str(lock_path), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
except FileExistsError as e:
raise OSError(f"cannot create lockfile in {p}") from e
raise OSError(f'cannot create lockfile in {p}') from e
else:
pid = os.getpid()
spid = str(pid).encode()
os.write(fd, spid)
os.close(fd)
if not lock_path.is_file():
raise OSError("lock path got renamed after successful creation")
raise OSError('lock path got renamed after successful creation')
return lock_path
@ -289,7 +291,7 @@ def maybe_delete_a_numbered_dir(path: Path) -> None:
lock_path = create_cleanup_lock(path)
parent = path.parent
garbage = parent.joinpath(f"garbage-{uuid.uuid4()}")
garbage = parent.joinpath(f'garbage-{uuid.uuid4()}')
path.rename(garbage)
rm_rf(garbage)
except OSError:
@ -362,14 +364,14 @@ def cleanup_dead_symlinks(root: Path):
def cleanup_numbered_dir(
root: Path, prefix: str, keep: int, consider_lock_dead_if_created_before: float
root: Path, prefix: str, keep: int, consider_lock_dead_if_created_before: float,
) -> None:
"""Cleanup for lock driven numbered directories."""
if not root.exists():
return
for path in cleanup_candidates(root, prefix, keep):
try_cleanup(path, consider_lock_dead_if_created_before)
for path in root.glob("garbage-*"):
for path in root.glob('garbage-*'):
try_cleanup(path, consider_lock_dead_if_created_before)
cleanup_dead_symlinks(root)
@ -417,7 +419,7 @@ def resolve_from_str(input: str, rootpath: Path) -> Path:
return rootpath.joinpath(input)
def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool:
def fnmatch_ex(pattern: str, path: str | os.PathLike[str]) -> bool:
"""A port of FNMatcher from py.path.common which works with PurePath() instances.
The difference between this algorithm and PurePath.match() is that the
@ -436,7 +438,7 @@ def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool:
* https://bugs.python.org/issue34731
"""
path = PurePath(path)
iswin32 = sys.platform.startswith("win")
iswin32 = sys.platform.startswith('win')
if iswin32 and sep not in pattern and posix_sep in pattern:
# Running on Windows, the pattern has no Windows path separators,
@ -449,11 +451,11 @@ def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool:
else:
name = str(path)
if path.is_absolute() and not os.path.isabs(pattern):
pattern = f"*{os.sep}{pattern}"
pattern = f'*{os.sep}{pattern}'
return fnmatch.fnmatch(name, pattern)
def parts(s: str) -> Set[str]:
def parts(s: str) -> set[str]:
parts = s.split(sep)
return {sep.join(parts[: i + 1]) or sep for i in range(len(parts))}
@ -463,15 +465,15 @@ def symlink_or_skip(src, dst, **kwargs):
try:
os.symlink(str(src), str(dst), **kwargs)
except OSError as e:
skip(f"symlinks not supported: {e}")
skip(f'symlinks not supported: {e}')
class ImportMode(Enum):
"""Possible values for `mode` parameter of `import_path`."""
prepend = "prepend"
append = "append"
importlib = "importlib"
prepend = 'prepend'
append = 'append'
importlib = 'importlib'
class ImportPathMismatchError(ImportError):
@ -484,9 +486,9 @@ class ImportPathMismatchError(ImportError):
def import_path(
path: Union[str, "os.PathLike[str]"],
path: str | os.PathLike[str],
*,
mode: Union[str, ImportMode] = ImportMode.prepend,
mode: str | ImportMode = ImportMode.prepend,
root: Path,
consider_namespace_packages: bool,
) -> ModuleType:
@ -534,7 +536,7 @@ def import_path(
# without touching sys.path.
try:
pkg_root, module_name = resolve_pkg_root_and_module_name(
path, consider_namespace_packages=consider_namespace_packages
path, consider_namespace_packages=consider_namespace_packages,
)
except CouldNotResolvePathError:
pass
@ -544,7 +546,7 @@ def import_path(
return sys.modules[module_name]
mod = _import_module_using_spec(
module_name, path, pkg_root, insert_modules=False
module_name, path, pkg_root, insert_modules=False,
)
if mod is not None:
return mod
@ -556,7 +558,7 @@ def import_path(
return sys.modules[module_name]
mod = _import_module_using_spec(
module_name, path, path.parent, insert_modules=True
module_name, path, path.parent, insert_modules=True,
)
if mod is None:
raise ImportError(f"Can't find module {module_name} at location {path}")
@ -564,7 +566,7 @@ def import_path(
try:
pkg_root, module_name = resolve_pkg_root_and_module_name(
path, consider_namespace_packages=consider_namespace_packages
path, consider_namespace_packages=consider_namespace_packages,
)
except CouldNotResolvePathError:
pkg_root, module_name = path.parent, path.stem
@ -584,19 +586,19 @@ def import_path(
importlib.import_module(module_name)
mod = sys.modules[module_name]
if path.name == "__init__.py":
if path.name == '__init__.py':
return mod
ignore = os.environ.get("PY_IGNORE_IMPORTMISMATCH", "")
if ignore != "1":
ignore = os.environ.get('PY_IGNORE_IMPORTMISMATCH', '')
if ignore != '1':
module_file = mod.__file__
if module_file is None:
raise ImportPathMismatchError(module_name, module_file, path)
if module_file.endswith((".pyc", ".pyo")):
if module_file.endswith(('.pyc', '.pyo')):
module_file = module_file[:-1]
if module_file.endswith(os.sep + "__init__.py"):
module_file = module_file[: -(len(os.sep + "__init__.py"))]
if module_file.endswith(os.sep + '__init__.py'):
module_file = module_file[: -(len(os.sep + '__init__.py'))]
try:
is_same = _is_same(str(path), module_file)
@ -610,8 +612,8 @@ def import_path(
def _import_module_using_spec(
module_name: str, module_path: Path, module_location: Path, *, insert_modules: bool
) -> Optional[ModuleType]:
module_name: str, module_path: Path, module_location: Path, *, insert_modules: bool,
) -> ModuleType | None:
"""
Tries to import a module by its canonical name, path to the .py file, and its
parent location.
@ -645,7 +647,7 @@ def _import_module_using_spec(
# Implement a special _is_same function on Windows which returns True if the two filenames
# compare equal, to circumvent os.path.samefile returning False for mounts in UNC (#7678).
if sys.platform.startswith("win"):
if sys.platform.startswith('win'):
def _is_same(f1: str, f2: str) -> bool:
return Path(f1) == Path(f2) or os.path.samefile(f1, f2)
@ -663,7 +665,7 @@ def module_name_from_path(path: Path, root: Path) -> str:
For example: path="projects/src/tests/test_foo.py" and root="/projects", the
resulting module name will be "src.tests.test_foo".
"""
path = path.with_suffix("")
path = path.with_suffix('')
try:
relative_path = path.relative_to(root)
except ValueError:
@ -676,18 +678,18 @@ def module_name_from_path(path: Path, root: Path) -> str:
# Module name for packages do not contain the __init__ file, unless
# the `__init__.py` file is at the root.
if len(path_parts) >= 2 and path_parts[-1] == "__init__":
if len(path_parts) >= 2 and path_parts[-1] == '__init__':
path_parts = path_parts[:-1]
# Module names cannot contain ".", normalize them to "_". This prevents
# a directory having a "." in the name (".env.310" for example) causing extra intermediate modules.
# Also, important to replace "." at the start of paths, as those are considered relative imports.
path_parts = tuple(x.replace(".", "_") for x in path_parts)
path_parts = tuple(x.replace('.', '_') for x in path_parts)
return ".".join(path_parts)
return '.'.join(path_parts)
def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) -> None:
def insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None:
"""
Used by ``import_path`` to create intermediate modules when using mode=importlib.
@ -695,10 +697,10 @@ def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) ->
to create empty modules "src" and "src.tests" after inserting "src.tests.test_foo",
otherwise "src.tests.test_foo" is not importable by ``__import__``.
"""
module_parts = module_name.split(".")
child_module: Union[ModuleType, None] = None
module: Union[ModuleType, None] = None
child_name: str = ""
module_parts = module_name.split('.')
child_module: ModuleType | None = None
module: ModuleType | None = None
child_name: str = ''
while module_name:
if module_name not in modules:
try:
@ -723,12 +725,12 @@ def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) ->
setattr(module, child_name, child_module)
modules[module_name] = module
# Keep track of the child module while moving up the tree.
child_module, child_name = module, module_name.rpartition(".")[-1]
child_module, child_name = module, module_name.rpartition('.')[-1]
module_parts.pop(-1)
module_name = ".".join(module_parts)
module_name = '.'.join(module_parts)
def resolve_package_path(path: Path) -> Optional[Path]:
def resolve_package_path(path: Path) -> Path | None:
"""Return the Python package path by looking for the last
directory upwards which still contains an __init__.py.
@ -737,7 +739,7 @@ def resolve_package_path(path: Path) -> Optional[Path]:
result = None
for parent in itertools.chain((path,), path.parents):
if parent.is_dir():
if not (parent / "__init__.py").is_file():
if not (parent / '__init__.py').is_file():
break
if not parent.name.isidentifier():
break
@ -746,8 +748,8 @@ def resolve_package_path(path: Path) -> Optional[Path]:
def resolve_pkg_root_and_module_name(
path: Path, *, consider_namespace_packages: bool = False
) -> Tuple[Path, str]:
path: Path, *, consider_namespace_packages: bool = False,
) -> tuple[Path, str]:
"""
Return the path to the directory of the root package that contains the
given Python file, and its module name:
@ -779,20 +781,20 @@ def resolve_pkg_root_and_module_name(
for parent in pkg_root.parents:
# If any of the parent paths has a __init__.py, it means it is not
# a namespace package (see the docs linked above).
if (parent / "__init__.py").is_file():
if (parent / '__init__.py').is_file():
break
if str(parent) in sys.path:
# Point the pkg_root to the root of the namespace package.
pkg_root = parent
break
names = list(path.with_suffix("").relative_to(pkg_root).parts)
if names[-1] == "__init__":
names = list(path.with_suffix('').relative_to(pkg_root).parts)
if names[-1] == '__init__':
names.pop()
module_name = ".".join(names)
module_name = '.'.join(names)
return pkg_root, module_name
raise CouldNotResolvePathError(f"Could not resolve for {path}")
raise CouldNotResolvePathError(f'Could not resolve for {path}')
class CouldNotResolvePathError(Exception):
@ -800,9 +802,9 @@ class CouldNotResolvePathError(Exception):
def scandir(
path: Union[str, "os.PathLike[str]"],
sort_key: Callable[["os.DirEntry[str]"], object] = lambda entry: entry.name,
) -> List["os.DirEntry[str]"]:
path: str | os.PathLike[str],
sort_key: Callable[[os.DirEntry[str]], object] = lambda entry: entry.name,
) -> list[os.DirEntry[str]]:
"""Scan a directory recursively, in breadth-first order.
The returned entries are sorted according to the given key.
@ -825,8 +827,8 @@ def scandir(
def visit(
path: Union[str, "os.PathLike[str]"], recurse: Callable[["os.DirEntry[str]"], bool]
) -> Iterator["os.DirEntry[str]"]:
path: str | os.PathLike[str], recurse: Callable[[os.DirEntry[str]], bool],
) -> Iterator[os.DirEntry[str]]:
"""Walk a directory recursively, in breadth-first order.
The `recurse` predicate determines whether a directory is recursed.
@ -840,7 +842,7 @@ def visit(
yield from visit(entry.path, recurse)
def absolutepath(path: Union[Path, str]) -> Path:
def absolutepath(path: Path | str) -> Path:
"""Convert a path to an absolute path using os.path.abspath.
Prefer this over Path.resolve() (see #6523).
@ -849,7 +851,7 @@ def absolutepath(path: Union[Path, str]) -> Path:
return Path(os.path.abspath(str(path)))
def commonpath(path1: Path, path2: Path) -> Optional[Path]:
def commonpath(path1: Path, path2: Path) -> Path | None:
"""Return the common part shared with the other path, or None if there is
no common part.

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,10 @@
"""Helper plugin for pytester; should not be loaded on its own."""
# This plugin contains assertions used by pytester. pytester cannot
# contain them itself, since it is imported by the `pytest` module,
# hence cannot be subject to assertion rewriting, which requires a
# module to not be already imported.
from __future__ import annotations
from typing import Dict
from typing import Optional
from typing import Sequence
@ -15,10 +16,10 @@ from _pytest.reports import TestReport
def assertoutcome(
outcomes: Tuple[
outcomes: tuple[
Sequence[TestReport],
Sequence[Union[CollectReport, TestReport]],
Sequence[Union[CollectReport, TestReport]],
Sequence[CollectReport | TestReport],
Sequence[CollectReport | TestReport],
],
passed: int = 0,
skipped: int = 0,
@ -28,49 +29,49 @@ def assertoutcome(
realpassed, realskipped, realfailed = outcomes
obtained = {
"passed": len(realpassed),
"skipped": len(realskipped),
"failed": len(realfailed),
'passed': len(realpassed),
'skipped': len(realskipped),
'failed': len(realfailed),
}
expected = {"passed": passed, "skipped": skipped, "failed": failed}
expected = {'passed': passed, 'skipped': skipped, 'failed': failed}
assert obtained == expected, outcomes
def assert_outcomes(
outcomes: Dict[str, int],
outcomes: dict[str, int],
passed: int = 0,
skipped: int = 0,
failed: int = 0,
errors: int = 0,
xpassed: int = 0,
xfailed: int = 0,
warnings: Optional[int] = None,
deselected: Optional[int] = None,
warnings: int | None = None,
deselected: int | None = None,
) -> None:
"""Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run."""
__tracebackhide__ = True
obtained = {
"passed": outcomes.get("passed", 0),
"skipped": outcomes.get("skipped", 0),
"failed": outcomes.get("failed", 0),
"errors": outcomes.get("errors", 0),
"xpassed": outcomes.get("xpassed", 0),
"xfailed": outcomes.get("xfailed", 0),
'passed': outcomes.get('passed', 0),
'skipped': outcomes.get('skipped', 0),
'failed': outcomes.get('failed', 0),
'errors': outcomes.get('errors', 0),
'xpassed': outcomes.get('xpassed', 0),
'xfailed': outcomes.get('xfailed', 0),
}
expected = {
"passed": passed,
"skipped": skipped,
"failed": failed,
"errors": errors,
"xpassed": xpassed,
"xfailed": xfailed,
'passed': passed,
'skipped': skipped,
'failed': failed,
'errors': errors,
'xpassed': xpassed,
'xfailed': xfailed,
}
if warnings is not None:
obtained["warnings"] = outcomes.get("warnings", 0)
expected["warnings"] = warnings
obtained['warnings'] = outcomes.get('warnings', 0)
expected['warnings'] = warnings
if deselected is not None:
obtained["deselected"] = outcomes.get("deselected", 0)
expected["deselected"] = deselected
obtained['deselected'] = outcomes.get('deselected', 0)
expected['deselected'] = deselected
assert obtained == expected

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,12 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import math
import pprint
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
import math
from numbers import Complex
import pprint
from types import TracebackType
from typing import Any
from typing import Callable
@ -33,25 +35,25 @@ if TYPE_CHECKING:
def _compare_approx(
full_object: object,
message_data: Sequence[Tuple[str, str, str]],
message_data: Sequence[tuple[str, str, str]],
number_of_elements: int,
different_ids: Sequence[object],
max_abs_diff: float,
max_rel_diff: float,
) -> List[str]:
) -> list[str]:
message_list = list(message_data)
message_list.insert(0, ("Index", "Obtained", "Expected"))
message_list.insert(0, ('Index', 'Obtained', 'Expected'))
max_sizes = [0, 0, 0]
for index, obtained, expected in message_list:
max_sizes[0] = max(max_sizes[0], len(index))
max_sizes[1] = max(max_sizes[1], len(obtained))
max_sizes[2] = max(max_sizes[2], len(expected))
explanation = [
f"comparison failed. Mismatched elements: {len(different_ids)} / {number_of_elements}:",
f"Max absolute difference: {max_abs_diff}",
f"Max relative difference: {max_rel_diff}",
f'comparison failed. Mismatched elements: {len(different_ids)} / {number_of_elements}:',
f'Max absolute difference: {max_abs_diff}',
f'Max relative difference: {max_rel_diff}',
] + [
f"{indexes:<{max_sizes[0]}} | {obtained:<{max_sizes[1]}} | {expected:<{max_sizes[2]}}"
f'{indexes:<{max_sizes[0]}} | {obtained:<{max_sizes[1]}} | {expected:<{max_sizes[2]}}'
for indexes, obtained, expected in message_list
]
return explanation
@ -79,11 +81,11 @@ class ApproxBase:
def __repr__(self) -> str:
raise NotImplementedError
def _repr_compare(self, other_side: Any) -> List[str]:
def _repr_compare(self, other_side: Any) -> list[str]:
return [
"comparison failed",
f"Obtained: {other_side}",
f"Expected: {self}",
'comparison failed',
f'Obtained: {other_side}',
f'Expected: {self}',
]
def __eq__(self, actual) -> bool:
@ -94,7 +96,7 @@ class ApproxBase:
def __bool__(self):
__tracebackhide__ = True
raise AssertionError(
"approx() is not supported in a boolean context.\nDid you mean: `assert a == approx(b)`?"
'approx() is not supported in a boolean context.\nDid you mean: `assert a == approx(b)`?',
)
# Ignore type because of https://github.com/python/mypy/issues/4266.
@ -103,7 +105,7 @@ class ApproxBase:
def __ne__(self, actual) -> bool:
return not (actual == self)
def _approx_scalar(self, x) -> "ApproxScalar":
def _approx_scalar(self, x) -> ApproxScalar:
if isinstance(x, Decimal):
return ApproxDecimal(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
@ -138,16 +140,16 @@ class ApproxNumpy(ApproxBase):
def __repr__(self) -> str:
list_scalars = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
self._approx_scalar, self.expected.tolist(),
)
return f"approx({list_scalars!r})"
return f'approx({list_scalars!r})'
def _repr_compare(self, other_side: "ndarray") -> List[str]:
def _repr_compare(self, other_side: ndarray) -> list[str]:
import itertools
import math
def get_value_from_nested_list(
nested_list: List[Any], nd_index: Tuple[Any, ...]
nested_list: list[Any], nd_index: tuple[Any, ...],
) -> Any:
"""
Helper function to get the value out of a nested list, given an n-dimensional index.
@ -160,13 +162,13 @@ class ApproxNumpy(ApproxBase):
np_array_shape = self.expected.shape
approx_side_as_seq = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
self._approx_scalar, self.expected.tolist(),
)
if np_array_shape != other_side.shape:
return [
"Impossible to compare arrays with different shapes.",
f"Shapes: {np_array_shape} and {other_side.shape}",
'Impossible to compare arrays with different shapes.',
f'Shapes: {np_array_shape} and {other_side.shape}',
]
number_of_elements = self.expected.size
@ -238,9 +240,9 @@ class ApproxMapping(ApproxBase):
with numeric values (the keys can be anything)."""
def __repr__(self) -> str:
return f"approx({({k: self._approx_scalar(v) for k, v in self.expected.items()})!r})"
return f'approx({({k: self._approx_scalar(v) for k, v in self.expected.items()})!r})'
def _repr_compare(self, other_side: Mapping[object, float]) -> List[str]:
def _repr_compare(self, other_side: Mapping[object, float]) -> list[str]:
import math
approx_side_as_map = {
@ -252,12 +254,12 @@ class ApproxMapping(ApproxBase):
max_rel_diff = -math.inf
different_ids = []
for (approx_key, approx_value), other_value in zip(
approx_side_as_map.items(), other_side.values()
approx_side_as_map.items(), other_side.values(),
):
if approx_value != other_value:
if approx_value.expected is not None and other_value is not None:
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
max_abs_diff, abs(approx_value.expected - other_value),
)
if approx_value.expected == 0.0:
max_rel_diff = math.inf
@ -265,8 +267,8 @@ class ApproxMapping(ApproxBase):
max_rel_diff = max(
max_rel_diff,
abs(
(approx_value.expected - other_value)
/ approx_value.expected
(approx_value.expected - other_value) /
approx_value.expected,
),
)
different_ids.append(approx_key)
@ -302,7 +304,7 @@ class ApproxMapping(ApproxBase):
__tracebackhide__ = True
for key, value in self.expected.items():
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
msg = 'pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}'
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
@ -313,15 +315,15 @@ class ApproxSequenceLike(ApproxBase):
seq_type = type(self.expected)
if seq_type not in (tuple, list):
seq_type = list
return f"approx({seq_type(self._approx_scalar(x) for x in self.expected)!r})"
return f'approx({seq_type(self._approx_scalar(x) for x in self.expected)!r})'
def _repr_compare(self, other_side: Sequence[float]) -> List[str]:
def _repr_compare(self, other_side: Sequence[float]) -> list[str]:
import math
if len(self.expected) != len(other_side):
return [
"Impossible to compare lists with different sizes.",
f"Lengths: {len(self.expected)} and {len(other_side)}",
'Impossible to compare lists with different sizes.',
f'Lengths: {len(self.expected)} and {len(other_side)}',
]
approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected)
@ -331,7 +333,7 @@ class ApproxSequenceLike(ApproxBase):
max_rel_diff = -math.inf
different_ids = []
for i, (approx_value, other_value) in enumerate(
zip(approx_side_as_map, other_side)
zip(approx_side_as_map, other_side),
):
if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value)
@ -371,7 +373,7 @@ class ApproxSequenceLike(ApproxBase):
__tracebackhide__ = True
for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
msg = 'pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}'
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
@ -380,8 +382,8 @@ class ApproxScalar(ApproxBase):
# Using Real should be better than this Union, but not possible yet:
# https://github.com/python/typeshed/pull/3108
DEFAULT_ABSOLUTE_TOLERANCE: Union[float, Decimal] = 1e-12
DEFAULT_RELATIVE_TOLERANCE: Union[float, Decimal] = 1e-6
DEFAULT_ABSOLUTE_TOLERANCE: float | Decimal = 1e-12
DEFAULT_RELATIVE_TOLERANCE: float | Decimal = 1e-6
def __repr__(self) -> str:
"""Return a string communicating both the expected value and the
@ -393,24 +395,24 @@ class ApproxScalar(ApproxBase):
# tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected) # type: ignore[arg-type]
abs(self.expected), # type: ignore[arg-type]
):
return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'.
try:
vetted_tolerance = f"{self.tolerance:.1e}"
vetted_tolerance = f'{self.tolerance:.1e}'
if (
isinstance(self.expected, Complex)
and self.expected.imag
and not math.isinf(self.tolerance)
isinstance(self.expected, Complex) and
self.expected.imag and
not math.isinf(self.tolerance)
):
vetted_tolerance += " ∠ ±180°"
vetted_tolerance += ' ∠ ±180°'
except ValueError:
vetted_tolerance = "???"
vetted_tolerance = '???'
return f"{self.expected} ± {vetted_tolerance}"
return f'{self.expected} ± {vetted_tolerance}'
def __eq__(self, actual) -> bool:
"""Return whether the given value is equal to the expected value
@ -429,8 +431,8 @@ class ApproxScalar(ApproxBase):
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
if not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
isinstance(self.expected, (Complex, Decimal)) and
isinstance(actual, (Complex, Decimal))
):
return False
@ -473,7 +475,7 @@ class ApproxScalar(ApproxBase):
if absolute_tolerance < 0:
raise ValueError(
f"absolute tolerance can't be negative: {absolute_tolerance}"
f"absolute tolerance can't be negative: {absolute_tolerance}",
)
if math.isnan(absolute_tolerance):
raise ValueError("absolute tolerance can't be NaN.")
@ -490,12 +492,12 @@ class ApproxScalar(ApproxBase):
# because we don't want to raise errors about the relative tolerance if
# we aren't even going to use it.
relative_tolerance = set_default(
self.rel, self.DEFAULT_RELATIVE_TOLERANCE
self.rel, self.DEFAULT_RELATIVE_TOLERANCE,
) * abs(self.expected)
if relative_tolerance < 0:
raise ValueError(
f"relative tolerance can't be negative: {relative_tolerance}"
f"relative tolerance can't be negative: {relative_tolerance}",
)
if math.isnan(relative_tolerance):
raise ValueError("relative tolerance can't be NaN.")
@ -507,8 +509,8 @@ class ApproxScalar(ApproxBase):
class ApproxDecimal(ApproxScalar):
"""Perform approximate comparisons where the expected value is a Decimal."""
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
DEFAULT_ABSOLUTE_TOLERANCE = Decimal('1e-12')
DEFAULT_RELATIVE_TOLERANCE = Decimal('1e-6')
def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
@ -711,20 +713,20 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
__tracebackhide__ = True
if isinstance(expected, Decimal):
cls: Type[ApproxBase] = ApproxDecimal
cls: type[ApproxBase] = ApproxDecimal
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif _is_numpy_array(expected):
expected = _as_numpy_array(expected)
cls = ApproxNumpy
elif (
hasattr(expected, "__getitem__")
and isinstance(expected, Sized)
and not isinstance(expected, (str, bytes))
hasattr(expected, '__getitem__') and
isinstance(expected, Sized) and
not isinstance(expected, (str, bytes))
):
cls = ApproxSequenceLike
elif isinstance(expected, Collection) and not isinstance(expected, (str, bytes)):
msg = f"pytest.approx() only supports ordered sequences, but got: {expected!r}"
msg = f'pytest.approx() only supports ordered sequences, but got: {expected!r}'
raise TypeError(msg)
else:
cls = ApproxScalar
@ -740,42 +742,42 @@ def _is_numpy_array(obj: object) -> bool:
return _as_numpy_array(obj) is not None
def _as_numpy_array(obj: object) -> Optional["ndarray"]:
def _as_numpy_array(obj: object) -> ndarray | None:
"""
Return an ndarray if the given object is implicitly convertible to ndarray,
and numpy is already imported, otherwise None.
"""
import sys
np: Any = sys.modules.get("numpy")
np: Any = sys.modules.get('numpy')
if np is not None:
# avoid infinite recursion on numpy scalars, which have __array__
if np.isscalar(obj):
return None
elif isinstance(obj, np.ndarray):
return obj
elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
elif hasattr(obj, '__array__') or hasattr('obj', '__array_interface__'):
return np.asarray(obj)
return None
# builtin pytest.raises helper
E = TypeVar("E", bound=BaseException)
E = TypeVar('E', bound=BaseException)
@overload
def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
expected_exception: type[E] | tuple[type[E], ...],
*,
match: Optional[Union[str, Pattern[str]]] = ...,
) -> "RaisesContext[E]":
match: str | Pattern[str] | None = ...,
) -> RaisesContext[E]:
...
@overload
def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
expected_exception: type[E] | tuple[type[E], ...],
func: Callable[..., Any],
*args: Any,
**kwargs: Any,
@ -784,8 +786,8 @@ def raises(
def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]], *args: Any, **kwargs: Any
) -> Union["RaisesContext[E]", _pytest._code.ExceptionInfo[E]]:
expected_exception: type[E] | tuple[type[E], ...], *args: Any, **kwargs: Any,
) -> RaisesContext[E] | _pytest._code.ExceptionInfo[E]:
r"""Assert that a code block/function call raises an exception type, or one of its subclasses.
:param expected_exception:
@ -928,34 +930,34 @@ def raises(
if not expected_exception:
raise ValueError(
f"Expected an exception type or a tuple of exception types, but got `{expected_exception!r}`. "
f'Expected an exception type or a tuple of exception types, but got `{expected_exception!r}`. '
f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'."
f"any special code to say 'this should never raise an exception'.",
)
if isinstance(expected_exception, type):
expected_exceptions: Tuple[Type[E], ...] = (expected_exception,)
expected_exceptions: tuple[type[E], ...] = (expected_exception,)
else:
expected_exceptions = expected_exception
for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
msg = 'expected exception must be a BaseException type, not {}' # type: ignore[unreachable]
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))
message = f"DID NOT RAISE {expected_exception}"
message = f'DID NOT RAISE {expected_exception}'
if not args:
match: Optional[Union[str, Pattern[str]]] = kwargs.pop("match", None)
match: str | Pattern[str] | None = kwargs.pop('match', None)
if kwargs:
msg = "Unexpected keyword arguments passed to pytest.raises: "
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
msg = 'Unexpected keyword arguments passed to pytest.raises: '
msg += ', '.join(sorted(kwargs))
msg += '\nUse context-manager form instead?'
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
raise TypeError(f'{func!r} object (type: {type(func)}) must be callable')
try:
func(*args[1:], **kwargs)
except expected_exception as e:
@ -971,14 +973,14 @@ raises.Exception = fail.Exception # type: ignore
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__(
self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
expected_exception: type[E] | tuple[type[E], ...],
message: str,
match_expr: Optional[Union[str, Pattern[str]]] = None,
match_expr: str | Pattern[str] | None = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
self.excinfo: Optional[_pytest._code.ExceptionInfo[E]] = None
self.excinfo: _pytest._code.ExceptionInfo[E] | None = None
def __enter__(self) -> _pytest._code.ExceptionInfo[E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
@ -986,9 +988,9 @@ class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
__tracebackhide__ = True
if exc_type is None:

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import sys
import pytest
@ -6,19 +8,19 @@ from pytest import Parser
def pytest_addoption(parser: Parser) -> None:
parser.addini("pythonpath", type="paths", help="Add paths to sys.path", default=[])
parser.addini('pythonpath', type='paths', help='Add paths to sys.path', default=[])
@pytest.hookimpl(tryfirst=True)
def pytest_load_initial_conftests(early_config: Config) -> None:
# `pythonpath = a b` will set `sys.path` to `[a, b, x, y, z, ...]`
for path in reversed(early_config.getini("pythonpath")):
for path in reversed(early_config.getini('pythonpath')):
sys.path.insert(0, str(path))
@pytest.hookimpl(trylast=True)
def pytest_unconfigure(config: Config) -> None:
for path in config.getini("pythonpath"):
for path in config.getini('pythonpath'):
path_str = str(path)
if path_str in sys.path:
sys.path.remove(path_str)

View file

@ -1,7 +1,10 @@
# mypy: allow-untyped-defs
"""Record warnings during test function execution."""
from pprint import pformat
from __future__ import annotations
import re
import warnings
from pprint import pformat
from types import TracebackType
from typing import Any
from typing import Callable
@ -16,7 +19,6 @@ from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
import warnings
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
@ -24,11 +26,11 @@ from _pytest.outcomes import Exit
from _pytest.outcomes import fail
T = TypeVar("T")
T = TypeVar('T')
@fixture
def recwarn() -> Generator["WarningsRecorder", None, None]:
def recwarn() -> Generator[WarningsRecorder, None, None]:
"""Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
See https://docs.pytest.org/en/latest/how-to/capture-warnings.html for information
@ -36,14 +38,14 @@ def recwarn() -> Generator["WarningsRecorder", None, None]:
"""
wrec = WarningsRecorder(_ispytest=True)
with wrec:
warnings.simplefilter("default")
warnings.simplefilter('default')
yield wrec
@overload
def deprecated_call(
*, match: Optional[Union[str, Pattern[str]]] = ...
) -> "WarningsRecorder":
*, match: str | Pattern[str] | None = ...,
) -> WarningsRecorder:
...
@ -53,8 +55,8 @@ def deprecated_call(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
def deprecated_call(
func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any
) -> Union["WarningsRecorder", Any]:
func: Callable[..., Any] | None = None, *args: Any, **kwargs: Any,
) -> WarningsRecorder | Any:
"""Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning`` or ``FutureWarning``.
This function can be used as a context manager::
@ -82,22 +84,22 @@ def deprecated_call(
if func is not None:
args = (func, *args)
return warns(
(DeprecationWarning, PendingDeprecationWarning, FutureWarning), *args, **kwargs
(DeprecationWarning, PendingDeprecationWarning, FutureWarning), *args, **kwargs,
)
@overload
def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = ...,
expected_warning: type[Warning] | tuple[type[Warning], ...] = ...,
*,
match: Optional[Union[str, Pattern[str]]] = ...,
) -> "WarningsChecker":
match: str | Pattern[str] | None = ...,
) -> WarningsChecker:
...
@overload
def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]],
expected_warning: type[Warning] | tuple[type[Warning], ...],
func: Callable[..., T],
*args: Any,
**kwargs: Any,
@ -106,11 +108,11 @@ def warns(
def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = Warning,
expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
*args: Any,
match: Optional[Union[str, Pattern[str]]] = None,
match: str | Pattern[str] | None = None,
**kwargs: Any,
) -> Union["WarningsChecker", Any]:
) -> WarningsChecker | Any:
r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or tuple
@ -155,16 +157,16 @@ def warns(
__tracebackhide__ = True
if not args:
if kwargs:
argnames = ", ".join(sorted(kwargs))
argnames = ', '.join(sorted(kwargs))
raise TypeError(
f"Unexpected keyword arguments passed to pytest.warns: {argnames}"
"\nUse context-manager form instead?"
f'Unexpected keyword arguments passed to pytest.warns: {argnames}'
'\nUse context-manager form instead?',
)
return WarningsChecker(expected_warning, match_expr=match, _ispytest=True)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
raise TypeError(f'{func!r} object (type: {type(func)}) must be callable')
with WarningsChecker(expected_warning, _ispytest=True):
return func(*args[1:], **kwargs)
@ -187,18 +189,18 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
# Type ignored due to the way typeshed handles warnings.catch_warnings.
super().__init__(record=True) # type: ignore[call-arg]
self._entered = False
self._list: List[warnings.WarningMessage] = []
self._list: list[warnings.WarningMessage] = []
@property
def list(self) -> List["warnings.WarningMessage"]:
def list(self) -> list[warnings.WarningMessage]:
"""The list of recorded warnings."""
return self._list
def __getitem__(self, i: int) -> "warnings.WarningMessage":
def __getitem__(self, i: int) -> warnings.WarningMessage:
"""Get a recorded warning by index."""
return self._list[i]
def __iter__(self) -> Iterator["warnings.WarningMessage"]:
def __iter__(self) -> Iterator[warnings.WarningMessage]:
"""Iterate through the recorded warnings."""
return iter(self._list)
@ -206,24 +208,24 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
"""The number of recorded warnings."""
return len(self._list)
def pop(self, cls: Type[Warning] = Warning) -> "warnings.WarningMessage":
def pop(self, cls: type[Warning] = Warning) -> warnings.WarningMessage:
"""Pop the first recorded warning which is an instance of ``cls``,
but not an instance of a child class of any other match.
Raises ``AssertionError`` if there is no match.
"""
best_idx: Optional[int] = None
best_idx: int | None = None
for i, w in enumerate(self._list):
if w.category == cls:
return self._list.pop(i) # exact match, stop looking
if issubclass(w.category, cls) and (
best_idx is None
or not issubclass(w.category, self._list[best_idx].category)
best_idx is None or
not issubclass(w.category, self._list[best_idx].category)
):
best_idx = i
if best_idx is not None:
return self._list.pop(best_idx)
__tracebackhide__ = True
raise AssertionError(f"{cls!r} not found in warning list")
raise AssertionError(f'{cls!r} not found in warning list')
def clear(self) -> None:
"""Clear the list of recorded warnings."""
@ -231,26 +233,26 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
# -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore
def __enter__(self) -> WarningsRecorder: # type: ignore
if self._entered:
__tracebackhide__ = True
raise RuntimeError(f"Cannot enter {self!r} twice")
raise RuntimeError(f'Cannot enter {self!r} twice')
_list = super().__enter__()
# record=True means it's None.
assert _list is not None
self._list = _list
warnings.simplefilter("always")
warnings.simplefilter('always')
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if not self._entered:
__tracebackhide__ = True
raise RuntimeError(f"Cannot exit {self!r} without entering first")
raise RuntimeError(f'Cannot exit {self!r} without entering first')
super().__exit__(exc_type, exc_val, exc_tb)
@ -263,22 +265,22 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
class WarningsChecker(WarningsRecorder):
def __init__(
self,
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = Warning,
match_expr: Optional[Union[str, Pattern[str]]] = None,
expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
match_expr: str | Pattern[str] | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
super().__init__(_ispytest=True)
msg = "exceptions must be derived from Warning, not %s"
msg = 'exceptions must be derived from Warning, not %s'
if isinstance(expected_warning, tuple):
for exc in expected_warning:
if not issubclass(exc, Warning):
raise TypeError(msg % type(exc))
expected_warning_tup = expected_warning
elif isinstance(expected_warning, type) and issubclass(
expected_warning, Warning
expected_warning, Warning,
):
expected_warning_tup = (expected_warning,)
else:
@ -290,14 +292,14 @@ class WarningsChecker(WarningsRecorder):
def matches(self, warning: warnings.WarningMessage) -> bool:
assert self.expected_warning is not None
return issubclass(warning.category, self.expected_warning) and bool(
self.match_expr is None or re.search(self.match_expr, str(warning.message))
self.match_expr is None or re.search(self.match_expr, str(warning.message)),
)
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
super().__exit__(exc_type, exc_val, exc_tb)
@ -308,9 +310,9 @@ class WarningsChecker(WarningsRecorder):
# when the warning doesn't happen. Control-flow exceptions should always
# propagate.
if exc_val is not None and (
not isinstance(exc_val, Exception)
not isinstance(exc_val, Exception) or
# Exit is an Exception, not a BaseException, for some reason.
or isinstance(exc_val, Exit)
isinstance(exc_val, Exit)
):
return
@ -320,14 +322,14 @@ class WarningsChecker(WarningsRecorder):
try:
if not any(issubclass(w.category, self.expected_warning) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f" Emitted warnings: {found_str()}."
f'DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n'
f' Emitted warnings: {found_str()}.',
)
elif not any(self.matches(w) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n"
f" Regex: {self.match_expr}\n"
f" Emitted warnings: {found_str()}."
f'DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n'
f' Regex: {self.match_expr}\n'
f' Emitted warnings: {found_str()}.',
)
finally:
# Whether or not any warnings matched, we want to re-emit all unmatched warnings.
@ -366,5 +368,5 @@ class WarningsChecker(WarningsRecorder):
# its first argument was not a string. But that case can't be
# distinguished from an invalid type.
raise TypeError(
f"Warning must be str or Warning, got {msg!r} (type {type(msg).__name__})"
f'Warning must be str or Warning, got {msg!r} (type {type(msg).__name__})',
)

View file

@ -1,7 +1,9 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
from io import StringIO
import os
from io import StringIO
from pprint import pprint
from typing import Any
from typing import cast
@ -47,25 +49,25 @@ def getworkerinfoline(node):
return node._workerinfocache
except AttributeError:
d = node.workerinfo
ver = "{}.{}.{}".format(*d["version_info"][:3])
node._workerinfocache = s = "[{}] {} -- Python {} {}".format(
d["id"], d["sysplatform"], ver, d["executable"]
ver = '{}.{}.{}'.format(*d['version_info'][:3])
node._workerinfocache = s = '[{}] {} -- Python {} {}'.format(
d['id'], d['sysplatform'], ver, d['executable'],
)
return s
_R = TypeVar("_R", bound="BaseReport")
_R = TypeVar('_R', bound='BaseReport')
class BaseReport:
when: Optional[str]
location: Optional[Tuple[str, Optional[int], str]]
longrepr: Union[
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
]
sections: List[Tuple[str, str]]
when: str | None
location: tuple[str, int | None, str] | None
longrepr: (
None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
)
sections: list[tuple[str, str]]
nodeid: str
outcome: Literal["passed", "failed", "skipped"]
outcome: Literal['passed', 'failed', 'skipped']
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)
@ -76,7 +78,7 @@ class BaseReport:
...
def toterminal(self, out: TerminalWriter) -> None:
if hasattr(self, "node"):
if hasattr(self, 'node'):
worker_info = getworkerinfoline(self.node)
if worker_info:
out.line(worker_info)
@ -85,17 +87,17 @@ class BaseReport:
if longrepr is None:
return
if hasattr(longrepr, "toterminal"):
if hasattr(longrepr, 'toterminal'):
longrepr_terminal = cast(TerminalRepr, longrepr)
longrepr_terminal.toterminal(out)
else:
try:
s = str(longrepr)
except UnicodeEncodeError:
s = "<unprintable longrepr>"
s = '<unprintable longrepr>'
out.line(s)
def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]:
def get_sections(self, prefix: str) -> Iterator[tuple[str, str]]:
for name, content in self.sections:
if name.startswith(prefix):
yield prefix, content
@ -120,8 +122,8 @@ class BaseReport:
.. versionadded:: 3.5
"""
return "\n".join(
content for (prefix, content) in self.get_sections("Captured log")
return '\n'.join(
content for (prefix, content) in self.get_sections('Captured log')
)
@property
@ -130,8 +132,8 @@ class BaseReport:
.. versionadded:: 3.0
"""
return "".join(
content for (prefix, content) in self.get_sections("Captured stdout")
return ''.join(
content for (prefix, content) in self.get_sections('Captured stdout')
)
@property
@ -140,29 +142,29 @@ class BaseReport:
.. versionadded:: 3.0
"""
return "".join(
content for (prefix, content) in self.get_sections("Captured stderr")
return ''.join(
content for (prefix, content) in self.get_sections('Captured stderr')
)
@property
def passed(self) -> bool:
"""Whether the outcome is passed."""
return self.outcome == "passed"
return self.outcome == 'passed'
@property
def failed(self) -> bool:
"""Whether the outcome is failed."""
return self.outcome == "failed"
return self.outcome == 'failed'
@property
def skipped(self) -> bool:
"""Whether the outcome is skipped."""
return self.outcome == "skipped"
return self.outcome == 'skipped'
@property
def fspath(self) -> str:
"""The path portion of the reported node, as a string."""
return self.nodeid.split("::")[0]
return self.nodeid.split('::')[0]
@property
def count_towards_summary(self) -> bool:
@ -177,7 +179,7 @@ class BaseReport:
return True
@property
def head_line(self) -> Optional[str]:
def head_line(self) -> str | None:
"""**Experimental** The head line shown with longrepr output for this
report, more commonly during traceback representation during
failures::
@ -199,11 +201,11 @@ class BaseReport:
def _get_verbose_word(self, config: Config):
_category, _short, verbose = config.hook.pytest_report_teststatus(
report=self, config=config
report=self, config=config,
)
return verbose
def _to_json(self) -> Dict[str, Any]:
def _to_json(self) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries,
suitable for serialization.
@ -214,7 +216,7 @@ class BaseReport:
return _report_to_json(self)
@classmethod
def _from_json(cls: Type[_R], reportdict: Dict[str, object]) -> _R:
def _from_json(cls: type[_R], reportdict: dict[str, object]) -> _R:
"""Create either a TestReport or CollectReport, depending on the calling class.
It is the callers responsibility to know which class to pass here.
@ -228,16 +230,16 @@ class BaseReport:
def _report_unserialization_failure(
type_name: str, report_class: Type[BaseReport], reportdict
type_name: str, report_class: type[BaseReport], reportdict,
) -> NoReturn:
url = "https://github.com/pytest-dev/pytest/issues"
url = 'https://github.com/pytest-dev/pytest/issues'
stream = StringIO()
pprint("-" * 100, stream=stream)
pprint("INTERNALERROR: Unknown entry type returned: %s" % type_name, stream=stream)
pprint("report_name: %s" % report_class, stream=stream)
pprint('-' * 100, stream=stream)
pprint('INTERNALERROR: Unknown entry type returned: %s' % type_name, stream=stream)
pprint('report_name: %s' % report_class, stream=stream)
pprint(reportdict, stream=stream)
pprint("Please report this bug at %s" % url, stream=stream)
pprint("-" * 100, stream=stream)
pprint('Please report this bug at %s' % url, stream=stream)
pprint('-' * 100, stream=stream)
raise RuntimeError(stream.getvalue())
@ -257,18 +259,18 @@ class TestReport(BaseReport):
def __init__(
self,
nodeid: str,
location: Tuple[str, Optional[int], str],
location: tuple[str, int | None, str],
keywords: Mapping[str, Any],
outcome: Literal["passed", "failed", "skipped"],
longrepr: Union[
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
],
when: Literal["setup", "call", "teardown"],
sections: Iterable[Tuple[str, str]] = (),
outcome: Literal['passed', 'failed', 'skipped'],
longrepr: (
None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
),
when: Literal['setup', 'call', 'teardown'],
sections: Iterable[tuple[str, str]] = (),
duration: float = 0,
start: float = 0,
stop: float = 0,
user_properties: Optional[Iterable[Tuple[str, object]]] = None,
user_properties: Iterable[tuple[str, object]] | None = None,
**extra,
) -> None:
#: Normalized collection nodeid.
@ -279,7 +281,7 @@ class TestReport(BaseReport):
#: collected one e.g. if a method is inherited from a different module.
#: The filesystempath may be relative to ``config.rootdir``.
#: The line number is 0-based.
self.location: Tuple[str, Optional[int], str] = location
self.location: tuple[str, int | None, str] = location
#: A name -> value dictionary containing all keywords and
#: markers associated with a test invocation.
@ -315,10 +317,10 @@ class TestReport(BaseReport):
self.__dict__.update(extra)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.nodeid!r} when={self.when!r} outcome={self.outcome!r}>"
return f'<{self.__class__.__name__} {self.nodeid!r} when={self.when!r} outcome={self.outcome!r}>'
@classmethod
def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport":
def from_item_and_call(cls, item: Item, call: CallInfo[None]) -> TestReport:
"""Create and fill a TestReport with standard item and call info.
:param item: The item.
@ -326,7 +328,7 @@ class TestReport(BaseReport):
"""
when = call.when
# Remove "collect" from the Literal type -- only for collection calls.
assert when != "collect"
assert when != 'collect'
duration = call.duration
start = call.start
stop = call.stop
@ -334,24 +336,24 @@ class TestReport(BaseReport):
excinfo = call.excinfo
sections = []
if not call.excinfo:
outcome: Literal["passed", "failed", "skipped"] = "passed"
longrepr: Union[
None,
ExceptionInfo[BaseException],
Tuple[str, int, str],
str,
TerminalRepr,
] = None
outcome: Literal['passed', 'failed', 'skipped'] = 'passed'
longrepr: (
None |
ExceptionInfo[BaseException] |
tuple[str, int, str] |
str |
TerminalRepr
) = None
else:
if not isinstance(excinfo, ExceptionInfo):
outcome = "failed"
outcome = 'failed'
longrepr = excinfo
elif isinstance(excinfo.value, skip.Exception):
outcome = "skipped"
outcome = 'skipped'
r = excinfo._getreprcrash()
assert (
r is not None
), "There should always be a traceback entry for skipping a test."
), 'There should always be a traceback entry for skipping a test.'
if excinfo.value._use_item_location:
path, line = item.reportinfo()[:2]
assert line is not None
@ -359,15 +361,15 @@ class TestReport(BaseReport):
else:
longrepr = (str(r.path), r.lineno, r.message)
else:
outcome = "failed"
if call.when == "call":
outcome = 'failed'
if call.when == 'call':
longrepr = item.repr_failure(excinfo)
else: # exception in setup or teardown
longrepr = item._repr_failure_py(
excinfo, style=item.config.getoption("tbstyle", "auto")
excinfo, style=item.config.getoption('tbstyle', 'auto'),
)
for rwhen, key, content in item._report_sections:
sections.append((f"Captured {key} {rwhen}", content))
sections.append((f'Captured {key} {rwhen}', content))
return cls(
item.nodeid,
item.location,
@ -390,17 +392,17 @@ class CollectReport(BaseReport):
Reports can contain arbitrary extra attributes.
"""
when = "collect"
when = 'collect'
def __init__(
self,
nodeid: str,
outcome: "Literal['passed', 'failed', 'skipped']",
longrepr: Union[
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
],
result: Optional[List[Union[Item, Collector]]],
sections: Iterable[Tuple[str, str]] = (),
outcome: Literal['passed', 'failed', 'skipped'],
longrepr: (
None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
),
result: list[Item | Collector] | None,
sections: Iterable[tuple[str, str]] = (),
**extra,
) -> None:
#: Normalized collection nodeid.
@ -426,11 +428,11 @@ class CollectReport(BaseReport):
@property
def location( # type:ignore[override]
self,
) -> Optional[Tuple[str, Optional[int], str]]:
) -> tuple[str, int | None, str] | None:
return (self.fspath, None, self.fspath)
def __repr__(self) -> str:
return f"<CollectReport {self.nodeid!r} lenresult={len(self.result)} outcome={self.outcome!r}>"
return f'<CollectReport {self.nodeid!r} lenresult={len(self.result)} outcome={self.outcome!r}>'
class CollectErrorRepr(TerminalRepr):
@ -442,31 +444,31 @@ class CollectErrorRepr(TerminalRepr):
def pytest_report_to_serializable(
report: Union[CollectReport, TestReport],
) -> Optional[Dict[str, Any]]:
report: CollectReport | TestReport,
) -> dict[str, Any] | None:
if isinstance(report, (TestReport, CollectReport)):
data = report._to_json()
data["$report_type"] = report.__class__.__name__
data['$report_type'] = report.__class__.__name__
return data
# TODO: Check if this is actually reachable.
return None # type: ignore[unreachable]
def pytest_report_from_serializable(
data: Dict[str, Any],
) -> Optional[Union[CollectReport, TestReport]]:
if "$report_type" in data:
if data["$report_type"] == "TestReport":
data: dict[str, Any],
) -> CollectReport | TestReport | None:
if '$report_type' in data:
if data['$report_type'] == 'TestReport':
return TestReport._from_json(data)
elif data["$report_type"] == "CollectReport":
elif data['$report_type'] == 'CollectReport':
return CollectReport._from_json(data)
assert False, "Unknown report_type unserialize data: {}".format(
data["$report_type"]
assert False, 'Unknown report_type unserialize data: {}'.format(
data['$report_type'],
)
return None
def _report_to_json(report: BaseReport) -> Dict[str, Any]:
def _report_to_json(report: BaseReport) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries,
suitable for serialization.
@ -474,72 +476,72 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
"""
def serialize_repr_entry(
entry: Union[ReprEntry, ReprEntryNative],
) -> Dict[str, Any]:
entry: ReprEntry | ReprEntryNative,
) -> dict[str, Any]:
data = dataclasses.asdict(entry)
for key, value in data.items():
if hasattr(value, "__dict__"):
if hasattr(value, '__dict__'):
data[key] = dataclasses.asdict(value)
entry_data = {"type": type(entry).__name__, "data": data}
entry_data = {'type': type(entry).__name__, 'data': data}
return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback) -> Dict[str, Any]:
def serialize_repr_traceback(reprtraceback: ReprTraceback) -> dict[str, Any]:
result = dataclasses.asdict(reprtraceback)
result["reprentries"] = [
result['reprentries'] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries
]
return result
def serialize_repr_crash(
reprcrash: Optional[ReprFileLocation],
) -> Optional[Dict[str, Any]]:
reprcrash: ReprFileLocation | None,
) -> dict[str, Any] | None:
if reprcrash is not None:
return dataclasses.asdict(reprcrash)
else:
return None
def serialize_exception_longrepr(rep: BaseReport) -> Dict[str, Any]:
def serialize_exception_longrepr(rep: BaseReport) -> dict[str, Any]:
assert rep.longrepr is not None
# TODO: Investigate whether the duck typing is really necessary here.
longrepr = cast(ExceptionRepr, rep.longrepr)
result: Dict[str, Any] = {
"reprcrash": serialize_repr_crash(longrepr.reprcrash),
"reprtraceback": serialize_repr_traceback(longrepr.reprtraceback),
"sections": longrepr.sections,
result: dict[str, Any] = {
'reprcrash': serialize_repr_crash(longrepr.reprcrash),
'reprtraceback': serialize_repr_traceback(longrepr.reprtraceback),
'sections': longrepr.sections,
}
if isinstance(longrepr, ExceptionChainRepr):
result["chain"] = []
result['chain'] = []
for repr_traceback, repr_crash, description in longrepr.chain:
result["chain"].append(
result['chain'].append(
(
serialize_repr_traceback(repr_traceback),
serialize_repr_crash(repr_crash),
description,
)
),
)
else:
result["chain"] = None
result['chain'] = None
return result
d = report.__dict__.copy()
if hasattr(report.longrepr, "toterminal"):
if hasattr(report.longrepr, "reprtraceback") and hasattr(
report.longrepr, "reprcrash"
if hasattr(report.longrepr, 'toterminal'):
if hasattr(report.longrepr, 'reprtraceback') and hasattr(
report.longrepr, 'reprcrash',
):
d["longrepr"] = serialize_exception_longrepr(report)
d['longrepr'] = serialize_exception_longrepr(report)
else:
d["longrepr"] = str(report.longrepr)
d['longrepr'] = str(report.longrepr)
else:
d["longrepr"] = report.longrepr
d['longrepr'] = report.longrepr
for name in d:
if isinstance(d[name], os.PathLike):
d[name] = os.fspath(d[name])
elif name == "result":
elif name == 'result':
d[name] = None # for now
return d
def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
def _report_kwargs_from_json(reportdict: dict[str, Any]) -> dict[str, Any]:
"""Return **kwargs that can be used to construct a TestReport or
CollectReport instance.
@ -547,76 +549,76 @@ def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
"""
def deserialize_repr_entry(entry_data):
data = entry_data["data"]
entry_type = entry_data["type"]
if entry_type == "ReprEntry":
data = entry_data['data']
entry_type = entry_data['type']
if entry_type == 'ReprEntry':
reprfuncargs = None
reprfileloc = None
reprlocals = None
if data["reprfuncargs"]:
reprfuncargs = ReprFuncArgs(**data["reprfuncargs"])
if data["reprfileloc"]:
reprfileloc = ReprFileLocation(**data["reprfileloc"])
if data["reprlocals"]:
reprlocals = ReprLocals(data["reprlocals"]["lines"])
if data['reprfuncargs']:
reprfuncargs = ReprFuncArgs(**data['reprfuncargs'])
if data['reprfileloc']:
reprfileloc = ReprFileLocation(**data['reprfileloc'])
if data['reprlocals']:
reprlocals = ReprLocals(data['reprlocals']['lines'])
reprentry: Union[ReprEntry, ReprEntryNative] = ReprEntry(
lines=data["lines"],
reprentry: ReprEntry | ReprEntryNative = ReprEntry(
lines=data['lines'],
reprfuncargs=reprfuncargs,
reprlocals=reprlocals,
reprfileloc=reprfileloc,
style=data["style"],
style=data['style'],
)
elif entry_type == "ReprEntryNative":
reprentry = ReprEntryNative(data["lines"])
elif entry_type == 'ReprEntryNative':
reprentry = ReprEntryNative(data['lines'])
else:
_report_unserialization_failure(entry_type, TestReport, reportdict)
return reprentry
def deserialize_repr_traceback(repr_traceback_dict):
repr_traceback_dict["reprentries"] = [
deserialize_repr_entry(x) for x in repr_traceback_dict["reprentries"]
repr_traceback_dict['reprentries'] = [
deserialize_repr_entry(x) for x in repr_traceback_dict['reprentries']
]
return ReprTraceback(**repr_traceback_dict)
def deserialize_repr_crash(repr_crash_dict: Optional[Dict[str, Any]]):
def deserialize_repr_crash(repr_crash_dict: dict[str, Any] | None):
if repr_crash_dict is not None:
return ReprFileLocation(**repr_crash_dict)
else:
return None
if (
reportdict["longrepr"]
and "reprcrash" in reportdict["longrepr"]
and "reprtraceback" in reportdict["longrepr"]
reportdict['longrepr'] and
'reprcrash' in reportdict['longrepr'] and
'reprtraceback' in reportdict['longrepr']
):
reprtraceback = deserialize_repr_traceback(
reportdict["longrepr"]["reprtraceback"]
reportdict['longrepr']['reprtraceback'],
)
reprcrash = deserialize_repr_crash(reportdict["longrepr"]["reprcrash"])
if reportdict["longrepr"]["chain"]:
reprcrash = deserialize_repr_crash(reportdict['longrepr']['reprcrash'])
if reportdict['longrepr']['chain']:
chain = []
for repr_traceback_data, repr_crash_data, description in reportdict[
"longrepr"
]["chain"]:
'longrepr'
]['chain']:
chain.append(
(
deserialize_repr_traceback(repr_traceback_data),
deserialize_repr_crash(repr_crash_data),
description,
)
),
)
exception_info: Union[
ExceptionChainRepr, ReprExceptionInfo
] = ExceptionChainRepr(chain)
exception_info: (
ExceptionChainRepr | ReprExceptionInfo
) = ExceptionChainRepr(chain)
else:
exception_info = ReprExceptionInfo(
reprtraceback=reprtraceback,
reprcrash=reprcrash,
)
for section in reportdict["longrepr"]["sections"]:
for section in reportdict['longrepr']['sections']:
exception_info.addsection(*section)
reportdict["longrepr"] = exception_info
reportdict['longrepr'] = exception_info
return reportdict

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
"""Basic collect and runtest protocol implementations."""
from __future__ import annotations
import bdb
import dataclasses
import os
@ -18,10 +20,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
from _pytest import timing
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
@ -37,6 +35,11 @@ from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
if sys.version_info[:2] < (3, 11):
from exceptiongroup import BaseExceptionGroup
@ -50,66 +53,66 @@ if TYPE_CHECKING:
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "Reporting", after="general")
group = parser.getgroup('terminal reporting', 'Reporting', after='general')
group.addoption(
"--durations",
action="store",
'--durations',
action='store',
type=int,
default=None,
metavar="N",
help="Show N slowest setup/test durations (N=0 for all)",
metavar='N',
help='Show N slowest setup/test durations (N=0 for all)',
)
group.addoption(
"--durations-min",
action="store",
'--durations-min',
action='store',
type=float,
default=0.005,
metavar="N",
help="Minimal duration in seconds for inclusion in slowest list. "
"Default: 0.005.",
metavar='N',
help='Minimal duration in seconds for inclusion in slowest list. '
'Default: 0.005.',
)
def pytest_terminal_summary(terminalreporter: "TerminalReporter") -> None:
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
durations = terminalreporter.config.option.durations
durations_min = terminalreporter.config.option.durations_min
verbose = terminalreporter.config.getvalue("verbose")
verbose = terminalreporter.config.getvalue('verbose')
if durations is None:
return
tr = terminalreporter
dlist = []
for replist in tr.stats.values():
for rep in replist:
if hasattr(rep, "duration"):
if hasattr(rep, 'duration'):
dlist.append(rep)
if not dlist:
return
dlist.sort(key=lambda x: x.duration, reverse=True) # type: ignore[no-any-return]
if not durations:
tr.write_sep("=", "slowest durations")
tr.write_sep('=', 'slowest durations')
else:
tr.write_sep("=", "slowest %s durations" % durations)
tr.write_sep('=', 'slowest %s durations' % durations)
dlist = dlist[:durations]
for i, rep in enumerate(dlist):
if verbose < 2 and rep.duration < durations_min:
tr.write_line("")
tr.write_line('')
tr.write_line(
f"({len(dlist) - i} durations < {durations_min:g}s hidden. Use -vv to show these durations.)"
f'({len(dlist) - i} durations < {durations_min:g}s hidden. Use -vv to show these durations.)',
)
break
tr.write_line(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}")
tr.write_line(f'{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}')
def pytest_sessionstart(session: "Session") -> None:
def pytest_sessionstart(session: Session) -> None:
session._setupstate = SetupState()
def pytest_sessionfinish(session: "Session") -> None:
def pytest_sessionfinish(session: Session) -> None:
session._setupstate.teardown_exact(None)
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
def pytest_runtest_protocol(item: Item, nextitem: Item | None) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
runtestprotocol(item, nextitem=nextitem)
@ -118,21 +121,21 @@ def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
item: Item, log: bool = True, nextitem: Item | None = None,
) -> list[TestReport]:
hasrequest = hasattr(item, '_request')
if hasrequest and not item._request: # type: ignore[attr-defined]
# This only happens if the item is re-run, as is done by
# pytest-rerunfailures.
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
rep = call_and_report(item, 'setup', log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
if item.config.getoption('setupshow', False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
reports.append(call_and_report(item, "call", log))
reports.append(call_and_report(item, "teardown", log, nextitem=nextitem))
if not item.config.getoption('setuponly', False):
reports.append(call_and_report(item, 'call', log))
reports.append(call_and_report(item, 'teardown', log, nextitem=nextitem))
# After all teardown hooks have been called
# want funcargs and request info to go away.
if hasrequest:
@ -145,21 +148,21 @@ def show_test_item(item: Item) -> None:
"""Show test function, parameters and the fixtures of the test item."""
tw = item.config.get_terminal_writer()
tw.line()
tw.write(" " * 8)
tw.write(' ' * 8)
tw.write(item.nodeid)
used_fixtures = sorted(getattr(item, "fixturenames", []))
used_fixtures = sorted(getattr(item, 'fixturenames', []))
if used_fixtures:
tw.write(" (fixtures used: {})".format(", ".join(used_fixtures)))
tw.write(' (fixtures used: {})'.format(', '.join(used_fixtures)))
tw.flush()
def pytest_runtest_setup(item: Item) -> None:
_update_current_test_var(item, "setup")
_update_current_test_var(item, 'setup')
item.session._setupstate.setup(item)
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
_update_current_test_var(item, 'call')
try:
del sys.last_type
del sys.last_value
@ -182,38 +185,38 @@ def pytest_runtest_call(item: Item) -> None:
raise e
def pytest_runtest_teardown(item: Item, nextitem: Optional[Item]) -> None:
_update_current_test_var(item, "teardown")
def pytest_runtest_teardown(item: Item, nextitem: Item | None) -> None:
_update_current_test_var(item, 'teardown')
item.session._setupstate.teardown_exact(nextitem)
_update_current_test_var(item, None)
def _update_current_test_var(
item: Item, when: Optional[Literal["setup", "call", "teardown"]]
item: Item, when: Literal['setup', 'call', 'teardown'] | None,
) -> None:
"""Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage.
If ``when`` is None, delete ``PYTEST_CURRENT_TEST`` from the environment.
"""
var_name = "PYTEST_CURRENT_TEST"
var_name = 'PYTEST_CURRENT_TEST'
if when:
value = f"{item.nodeid} ({when})"
value = f'{item.nodeid} ({when})'
# don't allow null bytes on environment variables (see #2644, #2957)
value = value.replace("\x00", "(null)")
value = value.replace('\x00', '(null)')
os.environ[var_name] = value
else:
os.environ.pop(var_name)
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if report.when in ("setup", "teardown"):
def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if report.when in ('setup', 'teardown'):
if report.failed:
# category, shortletter, verbose-word
return "error", "E", "ERROR"
return 'error', 'E', 'ERROR'
elif report.skipped:
return "skipped", "s", "SKIPPED"
return 'skipped', 's', 'SKIPPED'
else:
return "", "", ""
return '', '', ''
return None
@ -222,22 +225,22 @@ def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str
def call_and_report(
item: Item, when: Literal["setup", "call", "teardown"], log: bool = True, **kwds
item: Item, when: Literal['setup', 'call', 'teardown'], log: bool = True, **kwds,
) -> TestReport:
ihook = item.ihook
if when == "setup":
if when == 'setup':
runtest_hook: Callable[..., None] = ihook.pytest_runtest_setup
elif when == "call":
elif when == 'call':
runtest_hook = ihook.pytest_runtest_call
elif when == "teardown":
elif when == 'teardown':
runtest_hook = ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
assert False, f'Unhandled runtest hook case: {when}'
reraise: tuple[type[BaseException], ...] = (Exit,)
if not item.config.getoption('usepdb', False):
reraise += (KeyboardInterrupt,)
call = CallInfo.from_call(
lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise
lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise,
)
report: TestReport = ihook.pytest_runtest_makereport(item=item, call=call)
if log:
@ -247,13 +250,13 @@ def call_and_report(
return report
def check_interactive_exception(call: "CallInfo[object]", report: BaseReport) -> bool:
def check_interactive_exception(call: CallInfo[object], report: BaseReport) -> bool:
"""Check whether the call raised an exception that should be reported as
interactive."""
if call.excinfo is None:
# Didn't raise.
return False
if hasattr(report, "wasxfail"):
if hasattr(report, 'wasxfail'):
# Exception was expected.
return False
if isinstance(call.excinfo.value, (Skipped, bdb.BdbQuit)):
@ -262,7 +265,7 @@ def check_interactive_exception(call: "CallInfo[object]", report: BaseReport) ->
return True
TResult = TypeVar("TResult", covariant=True)
TResult = TypeVar('TResult', covariant=True)
@final
@ -270,9 +273,9 @@ TResult = TypeVar("TResult", covariant=True)
class CallInfo(Generic[TResult]):
"""Result/Exception info of a function invocation."""
_result: Optional[TResult]
_result: TResult | None
#: The captured exception of the call, if it raised.
excinfo: Optional[ExceptionInfo[BaseException]]
excinfo: ExceptionInfo[BaseException] | None
#: The system time when the call started, in seconds since the epoch.
start: float
#: The system time when the call ended, in seconds since the epoch.
@ -280,16 +283,16 @@ class CallInfo(Generic[TResult]):
#: The call duration, in seconds.
duration: float
#: The context of invocation: "collect", "setup", "call" or "teardown".
when: Literal["collect", "setup", "call", "teardown"]
when: Literal['collect', 'setup', 'call', 'teardown']
def __init__(
self,
result: Optional[TResult],
excinfo: Optional[ExceptionInfo[BaseException]],
result: TResult | None,
excinfo: ExceptionInfo[BaseException] | None,
start: float,
stop: float,
duration: float,
when: Literal["collect", "setup", "call", "teardown"],
when: Literal['collect', 'setup', 'call', 'teardown'],
*,
_ispytest: bool = False,
) -> None:
@ -308,7 +311,7 @@ class CallInfo(Generic[TResult]):
Can only be accessed if excinfo is None.
"""
if self.excinfo is not None:
raise AttributeError(f"{self!r} has no valid result")
raise AttributeError(f'{self!r} has no valid result')
# The cast is safe because an exception wasn't raised, hence
# _result has the expected function return type (which may be
# None, that's why a cast and not an assert).
@ -318,11 +321,11 @@ class CallInfo(Generic[TResult]):
def from_call(
cls,
func: Callable[[], TResult],
when: Literal["collect", "setup", "call", "teardown"],
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
when: Literal['collect', 'setup', 'call', 'teardown'],
reraise: None | (
type[BaseException] | tuple[type[BaseException], ...]
) = None,
) -> CallInfo[TResult]:
"""Call func, wrapping the result in a CallInfo.
:param func:
@ -337,7 +340,7 @@ class CallInfo(Generic[TResult]):
start = timing.time()
precise_start = timing.perf_counter()
try:
result: Optional[TResult] = func()
result: TResult | None = func()
except BaseException:
excinfo = ExceptionInfo.from_current()
if reraise is not None and isinstance(excinfo.value, reraise):
@ -359,8 +362,8 @@ class CallInfo(Generic[TResult]):
def __repr__(self) -> str:
if self.excinfo is None:
return f"<CallInfo when={self.when!r} result: {self._result!r}>"
return f"<CallInfo when={self.when!r} excinfo={self.excinfo!r}>"
return f'<CallInfo when={self.when!r} result: {self._result!r}>'
return f'<CallInfo when={self.when!r} excinfo={self.excinfo!r}>'
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
@ -368,7 +371,7 @@ def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
def pytest_make_collect_report(collector: Collector) -> CollectReport:
def collect() -> List[Union[Item, Collector]]:
def collect() -> list[Item | Collector]:
# Before collecting, if this is a Directory, load the conftests.
# If a conftest import fails to load, it is considered a collection
# error of the Directory collector. This is why it's done inside of the
@ -378,36 +381,36 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport:
if isinstance(collector, Directory):
collector.config.pluginmanager._loadconftestmodules(
collector.path,
collector.config.getoption("importmode"),
collector.config.getoption('importmode'),
rootpath=collector.config.rootpath,
consider_namespace_packages=collector.config.getini(
"consider_namespace_packages"
'consider_namespace_packages',
),
)
return list(collector.collect())
call = CallInfo.from_call(collect, "collect")
longrepr: Union[None, Tuple[str, int, str], str, TerminalRepr] = None
call = CallInfo.from_call(collect, 'collect')
longrepr: None | tuple[str, int, str] | str | TerminalRepr = None
if not call.excinfo:
outcome: Literal["passed", "skipped", "failed"] = "passed"
outcome: Literal['passed', 'skipped', 'failed'] = 'passed'
else:
skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest")
unittest = sys.modules.get('unittest')
if unittest is not None:
# Type ignored because unittest is loaded dynamically.
skip_exceptions.append(unittest.SkipTest) # type: ignore
if isinstance(call.excinfo.value, tuple(skip_exceptions)):
outcome = "skipped"
r_ = collector._repr_failure_py(call.excinfo, "line")
outcome = 'skipped'
r_ = collector._repr_failure_py(call.excinfo, 'line')
assert isinstance(r_, ExceptionChainRepr), repr(r_)
r = r_.reprcrash
assert r
longrepr = (str(r.path), r.lineno, r.message)
else:
outcome = "failed"
outcome = 'failed'
errorinfo = collector.repr_failure(call.excinfo)
if not hasattr(errorinfo, "toterminal"):
if not hasattr(errorinfo, 'toterminal'):
assert isinstance(errorinfo, str)
errorinfo = CollectErrorRepr(errorinfo)
longrepr = errorinfo
@ -483,13 +486,13 @@ class SetupState:
def __init__(self) -> None:
# The stack is in the dict insertion order.
self.stack: Dict[
self.stack: dict[
Node,
Tuple[
tuple[
# Node's finalizers.
List[Callable[[], object]],
list[Callable[[], object]],
# Node's exception, if its setup raised.
Optional[Union[OutcomeException, Exception]],
OutcomeException | Exception | None,
],
] = {}
@ -500,11 +503,11 @@ class SetupState:
# If a collector fails its setup, fail its entire subtree of items.
# The setup is not retried for each item - the same exception is used.
for col, (finalizers, exc) in self.stack.items():
assert col in needed_collectors, "previous item was not torn down properly"
assert col in needed_collectors, 'previous item was not torn down properly'
if exc:
raise exc
for col in needed_collectors[len(self.stack) :]:
for col in needed_collectors[len(self.stack):]:
assert col not in self.stack
# Push onto the stack.
self.stack[col] = ([col.teardown], None)
@ -524,7 +527,7 @@ class SetupState:
assert node in self.stack, (node, self.stack)
self.stack[node][0].append(finalizer)
def teardown_exact(self, nextitem: Optional[Item]) -> None:
def teardown_exact(self, nextitem: Item | None) -> None:
"""Teardown the current stack up until reaching nodes that nextitem
also descends from.
@ -532,7 +535,7 @@ class SetupState:
stack is torn down.
"""
needed_collectors = nextitem and nextitem.listchain() or []
exceptions: List[BaseException] = []
exceptions: list[BaseException] = []
while self.stack:
if list(self.stack.keys()) == needed_collectors[: len(self.stack)]:
break
@ -548,13 +551,13 @@ class SetupState:
if len(these_exceptions) == 1:
exceptions.extend(these_exceptions)
elif these_exceptions:
msg = f"errors while tearing down {node!r}"
msg = f'errors while tearing down {node!r}'
exceptions.append(BaseExceptionGroup(msg, these_exceptions[::-1]))
if len(exceptions) == 1:
raise exceptions[0]
elif exceptions:
raise BaseExceptionGroup("errors during test teardown", exceptions[::-1])
raise BaseExceptionGroup('errors during test teardown', exceptions[::-1])
if nextitem is None:
assert not self.stack
@ -563,7 +566,7 @@ def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook
ihook.pytest_collectstart(collector=collector)
rep: CollectReport = ihook.pytest_make_collect_report(collector=collector)
call = rep.__dict__.pop("call", None)
call = rep.__dict__.pop('call', None)
if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep)
return rep

View file

@ -7,6 +7,7 @@ would cause circular references.
Also this makes the module light to import, as it should.
"""
from __future__ import annotations
from enum import Enum
from functools import total_ordering
@ -14,7 +15,7 @@ from typing import Literal
from typing import Optional
_ScopeName = Literal["session", "package", "module", "class", "function"]
_ScopeName = Literal['session', 'package', 'module', 'class', 'function']
@total_ordering
@ -32,35 +33,35 @@ class Scope(Enum):
"""
# Scopes need to be listed from lower to higher.
Function: _ScopeName = "function"
Class: _ScopeName = "class"
Module: _ScopeName = "module"
Package: _ScopeName = "package"
Session: _ScopeName = "session"
Function: _ScopeName = 'function'
Class: _ScopeName = 'class'
Module: _ScopeName = 'module'
Package: _ScopeName = 'package'
Session: _ScopeName = 'session'
def next_lower(self) -> "Scope":
def next_lower(self) -> Scope:
"""Return the next lower scope."""
index = _SCOPE_INDICES[self]
if index == 0:
raise ValueError(f"{self} is the lower-most scope")
raise ValueError(f'{self} is the lower-most scope')
return _ALL_SCOPES[index - 1]
def next_higher(self) -> "Scope":
def next_higher(self) -> Scope:
"""Return the next higher scope."""
index = _SCOPE_INDICES[self]
if index == len(_SCOPE_INDICES) - 1:
raise ValueError(f"{self} is the upper-most scope")
raise ValueError(f'{self} is the upper-most scope')
return _ALL_SCOPES[index + 1]
def __lt__(self, other: "Scope") -> bool:
def __lt__(self, other: Scope) -> bool:
self_index = _SCOPE_INDICES[self]
other_index = _SCOPE_INDICES[other]
return self_index < other_index
@classmethod
def from_user(
cls, scope_name: _ScopeName, descr: str, where: Optional[str] = None
) -> "Scope":
cls, scope_name: _ScopeName, descr: str, where: str | None = None,
) -> Scope:
"""
Given a scope name from the user, return the equivalent Scope enum. Should be used
whenever we want to convert a user provided scope name to its enum object.
@ -75,7 +76,7 @@ class Scope(Enum):
except ValueError:
fail(
"{} {}got an unexpected scope value '{}'".format(
descr, f"from {where} " if where else "", scope_name
descr, f'from {where} ' if where else '', scope_name,
),
pytrace=False,
)

View file

@ -1,7 +1,10 @@
from __future__ import annotations
from typing import Generator
from typing import Optional
from typing import Union
import pytest
from _pytest._io.saferepr import saferepr
from _pytest.config import Config
from _pytest.config import ExitCode
@ -9,34 +12,33 @@ from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
from _pytest.scope import Scope
import pytest
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group = parser.getgroup('debugconfig')
group.addoption(
"--setuponly",
"--setup-only",
action="store_true",
help="Only setup fixtures, do not execute tests",
'--setuponly',
'--setup-only',
action='store_true',
help='Only setup fixtures, do not execute tests',
)
group.addoption(
"--setupshow",
"--setup-show",
action="store_true",
help="Show setup of fixtures while executing tests",
'--setupshow',
'--setup-show',
action='store_true',
help='Show setup of fixtures while executing tests',
)
@pytest.hookimpl(wrapper=True)
def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest
fixturedef: FixtureDef[object], request: SubRequest,
) -> Generator[None, object, object]:
try:
return (yield)
finally:
if request.config.option.setupshow:
if hasattr(request, "param"):
if hasattr(request, 'param'):
# Save the fixture parameter so ._show_fixture_action() can
# display it now and during the teardown (in .finish()).
if fixturedef.ids:
@ -47,24 +49,24 @@ def pytest_fixture_setup(
else:
param = request.param
fixturedef.cached_param = param # type: ignore[attr-defined]
_show_fixture_action(fixturedef, request.config, "SETUP")
_show_fixture_action(fixturedef, request.config, 'SETUP')
def pytest_fixture_post_finalizer(
fixturedef: FixtureDef[object], request: SubRequest
fixturedef: FixtureDef[object], request: SubRequest,
) -> None:
if fixturedef.cached_result is not None:
config = request.config
if config.option.setupshow:
_show_fixture_action(fixturedef, request.config, "TEARDOWN")
if hasattr(fixturedef, "cached_param"):
_show_fixture_action(fixturedef, request.config, 'TEARDOWN')
if hasattr(fixturedef, 'cached_param'):
del fixturedef.cached_param # type: ignore[attr-defined]
def _show_fixture_action(
fixturedef: FixtureDef[object], config: Config, msg: str
fixturedef: FixtureDef[object], config: Config, msg: str,
) -> None:
capman = config.pluginmanager.getplugin("capturemanager")
capman = config.pluginmanager.getplugin('capturemanager')
if capman:
capman.suspend_global_capture()
@ -72,22 +74,22 @@ def _show_fixture_action(
tw.line()
# Use smaller indentation the higher the scope: Session = 0, Package = 1, etc.
scope_indent = list(reversed(Scope)).index(fixturedef._scope)
tw.write(" " * 2 * scope_indent)
tw.write(' ' * 2 * scope_indent)
tw.write(
"{step} {scope} {fixture}".format( # noqa: UP032 (Readability)
'{step} {scope} {fixture}'.format( # noqa: UP032 (Readability)
step=msg.ljust(8), # align the output to TEARDOWN
scope=fixturedef.scope[0].upper(),
fixture=fixturedef.argname,
)
),
)
if msg == "SETUP":
deps = sorted(arg for arg in fixturedef.argnames if arg != "request")
if msg == 'SETUP':
deps = sorted(arg for arg in fixturedef.argnames if arg != 'request')
if deps:
tw.write(" (fixtures used: {})".format(", ".join(deps)))
tw.write(' (fixtures used: {})'.format(', '.join(deps)))
if hasattr(fixturedef, "cached_param"):
tw.write(f"[{saferepr(fixturedef.cached_param, maxsize=42)}]") # type: ignore[attr-defined]
if hasattr(fixturedef, 'cached_param'):
tw.write(f'[{saferepr(fixturedef.cached_param, maxsize=42)}]') # type: ignore[attr-defined]
tw.flush()
@ -96,7 +98,7 @@ def _show_fixture_action(
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setuponly:
config.option.setupshow = True
return None

View file

@ -1,29 +1,31 @@
from __future__ import annotations
from typing import Optional
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
import pytest
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group = parser.getgroup('debugconfig')
group.addoption(
"--setupplan",
"--setup-plan",
action="store_true",
help="Show what fixtures and tests would be executed but "
'--setupplan',
'--setup-plan',
action='store_true',
help='Show what fixtures and tests would be executed but '
"don't execute anything",
)
@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest
) -> Optional[object]:
fixturedef: FixtureDef[object], request: SubRequest,
) -> object | None:
# Will return a dummy fixture if the setuponly option is provided.
if request.config.option.setupplan:
my_cache_key = fixturedef.cache_key(request)
@ -33,7 +35,7 @@ def pytest_fixture_setup(
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setupplan:
config.option.setuponly = True
config.option.setupshow = True

View file

@ -1,11 +1,13 @@
# mypy: allow-untyped-defs
"""Support for skip/xfail functions and markers."""
from collections.abc import Mapping
from __future__ import annotations
import dataclasses
import os
import platform
import sys
import traceback
from collections.abc import Mapping
from typing import Generator
from typing import Optional
from typing import Tuple
@ -26,21 +28,21 @@ from _pytest.stash import StashKey
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group = parser.getgroup('general')
group.addoption(
"--runxfail",
action="store_true",
dest="runxfail",
'--runxfail',
action='store_true',
dest='runxfail',
default=False,
help="Report the results of xfail tests as if they were not marked",
help='Report the results of xfail tests as if they were not marked',
)
parser.addini(
"xfail_strict",
"Default for the strict parameter of xfail "
"markers when not given explicitly (default: False)",
'xfail_strict',
'Default for the strict parameter of xfail '
'markers when not given explicitly (default: False)',
default=False,
type="bool",
type='bool',
)
@ -50,40 +52,40 @@ def pytest_configure(config: Config) -> None:
import pytest
old = pytest.xfail
config.add_cleanup(lambda: setattr(pytest, "xfail", old))
config.add_cleanup(lambda: setattr(pytest, 'xfail', old))
def nop(*args, **kwargs):
pass
nop.Exception = xfail.Exception # type: ignore[attr-defined]
setattr(pytest, "xfail", nop)
setattr(pytest, 'xfail', nop)
config.addinivalue_line(
"markers",
"skip(reason=None): skip the given test function with an optional reason. "
'markers',
'skip(reason=None): skip the given test function with an optional reason. '
'Example: skip(reason="no way of currently testing this") skips the '
"test.",
'test.',
)
config.addinivalue_line(
"markers",
"skipif(condition, ..., *, reason=...): "
"skip the given test function if any of the conditions evaluate to True. "
'markers',
'skipif(condition, ..., *, reason=...): '
'skip the given test function if any of the conditions evaluate to True. '
"Example: skipif(sys.platform == 'win32') skips the test if we are on the win32 platform. "
"See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-skipif",
'See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-skipif',
)
config.addinivalue_line(
"markers",
"xfail(condition, ..., *, reason=..., run=True, raises=None, strict=xfail_strict): "
"mark the test function as an expected failure if any of the conditions "
"evaluate to True. Optionally specify a reason for better reporting "
'markers',
'xfail(condition, ..., *, reason=..., run=True, raises=None, strict=xfail_strict): '
'mark the test function as an expected failure if any of the conditions '
'evaluate to True. Optionally specify a reason for better reporting '
"and run=False if you don't even want to execute the test function. "
"If only specific exception(s) are expected, you can list them in "
"raises, and if the test fails in other ways, it will be reported as "
"a true failure. See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-xfail",
'If only specific exception(s) are expected, you can list them in '
'raises, and if the test fails in other ways, it will be reported as '
'a true failure. See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-xfail',
)
def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool, str]:
def evaluate_condition(item: Item, mark: Mark, condition: object) -> tuple[bool, str]:
"""Evaluate a single skipif/xfail condition.
If an old-style string condition is given, it is eval()'d, otherwise the
@ -95,40 +97,40 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
# String condition.
if isinstance(condition, str):
globals_ = {
"os": os,
"sys": sys,
"platform": platform,
"config": item.config,
'os': os,
'sys': sys,
'platform': platform,
'config': item.config,
}
for dictionary in reversed(
item.ihook.pytest_markeval_namespace(config=item.config)
item.ihook.pytest_markeval_namespace(config=item.config),
):
if not isinstance(dictionary, Mapping):
raise ValueError(
f"pytest_markeval_namespace() needs to return a dict, got {dictionary!r}"
f'pytest_markeval_namespace() needs to return a dict, got {dictionary!r}',
)
globals_.update(dictionary)
if hasattr(item, "obj"):
if hasattr(item, 'obj'):
globals_.update(item.obj.__globals__) # type: ignore[attr-defined]
try:
filename = f"<{mark.name} condition>"
condition_code = compile(condition, filename, "eval")
filename = f'<{mark.name} condition>'
condition_code = compile(condition, filename, 'eval')
result = eval(condition_code, globals_)
except SyntaxError as exc:
msglines = [
"Error evaluating %r condition" % mark.name,
" " + condition,
" " + " " * (exc.offset or 0) + "^",
"SyntaxError: invalid syntax",
'Error evaluating %r condition' % mark.name,
' ' + condition,
' ' + ' ' * (exc.offset or 0) + '^',
'SyntaxError: invalid syntax',
]
fail("\n".join(msglines), pytrace=False)
fail('\n'.join(msglines), pytrace=False)
except Exception as exc:
msglines = [
"Error evaluating %r condition" % mark.name,
" " + condition,
'Error evaluating %r condition' % mark.name,
' ' + condition,
*traceback.format_exception_only(type(exc), exc),
]
fail("\n".join(msglines), pytrace=False)
fail('\n'.join(msglines), pytrace=False)
# Boolean condition.
else:
@ -136,20 +138,20 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
result = bool(condition)
except Exception as exc:
msglines = [
"Error evaluating %r condition as a boolean" % mark.name,
'Error evaluating %r condition as a boolean' % mark.name,
*traceback.format_exception_only(type(exc), exc),
]
fail("\n".join(msglines), pytrace=False)
fail('\n'.join(msglines), pytrace=False)
reason = mark.kwargs.get("reason", None)
reason = mark.kwargs.get('reason', None)
if reason is None:
if isinstance(condition, str):
reason = "condition: " + condition
reason = 'condition: ' + condition
else:
# XXX better be checked at collection time
msg = (
"Error evaluating %r: " % mark.name
+ "you need to specify reason=STRING when using booleans as conditions."
'Error evaluating %r: ' % mark.name +
'you need to specify reason=STRING when using booleans as conditions.'
)
fail(msg, pytrace=False)
@ -160,20 +162,20 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
class Skip:
"""The result of evaluate_skip_marks()."""
reason: str = "unconditional skip"
reason: str = 'unconditional skip'
def evaluate_skip_marks(item: Item) -> Optional[Skip]:
def evaluate_skip_marks(item: Item) -> Skip | None:
"""Evaluate skip and skipif marks on item, returning Skip if triggered."""
for mark in item.iter_markers(name="skipif"):
if "condition" not in mark.kwargs:
for mark in item.iter_markers(name='skipif'):
if 'condition' not in mark.kwargs:
conditions = mark.args
else:
conditions = (mark.kwargs["condition"],)
conditions = (mark.kwargs['condition'],)
# Unconditional.
if not conditions:
reason = mark.kwargs.get("reason", "")
reason = mark.kwargs.get('reason', '')
return Skip(reason)
# If any of the conditions are true.
@ -182,11 +184,11 @@ def evaluate_skip_marks(item: Item) -> Optional[Skip]:
if result:
return Skip(reason)
for mark in item.iter_markers(name="skip"):
for mark in item.iter_markers(name='skip'):
try:
return Skip(*mark.args, **mark.kwargs)
except TypeError as e:
raise TypeError(str(e) + " - maybe you meant pytest.mark.skipif?") from None
raise TypeError(str(e) + ' - maybe you meant pytest.mark.skipif?') from None
return None
@ -195,28 +197,28 @@ def evaluate_skip_marks(item: Item) -> Optional[Skip]:
class Xfail:
"""The result of evaluate_xfail_marks()."""
__slots__ = ("reason", "run", "strict", "raises")
__slots__ = ('reason', 'run', 'strict', 'raises')
reason: str
run: bool
strict: bool
raises: Optional[Tuple[Type[BaseException], ...]]
raises: tuple[type[BaseException], ...] | None
def evaluate_xfail_marks(item: Item) -> Optional[Xfail]:
def evaluate_xfail_marks(item: Item) -> Xfail | None:
"""Evaluate xfail marks on item, returning Xfail if triggered."""
for mark in item.iter_markers(name="xfail"):
run = mark.kwargs.get("run", True)
strict = mark.kwargs.get("strict", item.config.getini("xfail_strict"))
raises = mark.kwargs.get("raises", None)
if "condition" not in mark.kwargs:
for mark in item.iter_markers(name='xfail'):
run = mark.kwargs.get('run', True)
strict = mark.kwargs.get('strict', item.config.getini('xfail_strict'))
raises = mark.kwargs.get('raises', None)
if 'condition' not in mark.kwargs:
conditions = mark.args
else:
conditions = (mark.kwargs["condition"],)
conditions = (mark.kwargs['condition'],)
# Unconditional.
if not conditions:
reason = mark.kwargs.get("reason", "")
reason = mark.kwargs.get('reason', '')
return Xfail(reason, run, strict, raises)
# If any of the conditions are true.
@ -240,7 +242,7 @@ def pytest_runtest_setup(item: Item) -> None:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason)
xfail('[NOTRUN] ' + xfailed.reason)
@hookimpl(wrapper=True)
@ -250,7 +252,7 @@ def pytest_runtest_call(item: Item) -> Generator[None, None, None]:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason)
xfail('[NOTRUN] ' + xfailed.reason)
try:
return (yield)
@ -263,7 +265,7 @@ def pytest_runtest_call(item: Item) -> Generator[None, None, None]:
@hookimpl(wrapper=True)
def pytest_runtest_makereport(
item: Item, call: CallInfo[None]
item: Item, call: CallInfo[None],
) -> Generator[None, TestReport, TestReport]:
rep = yield
xfailed = item.stash.get(xfailed_key, None)
@ -271,30 +273,30 @@ def pytest_runtest_makereport(
pass # don't interfere
elif call.excinfo and isinstance(call.excinfo.value, xfail.Exception):
assert call.excinfo.value.msg is not None
rep.wasxfail = "reason: " + call.excinfo.value.msg
rep.outcome = "skipped"
rep.wasxfail = 'reason: ' + call.excinfo.value.msg
rep.outcome = 'skipped'
elif not rep.skipped and xfailed:
if call.excinfo:
raises = xfailed.raises
if raises is not None and not isinstance(call.excinfo.value, raises):
rep.outcome = "failed"
rep.outcome = 'failed'
else:
rep.outcome = "skipped"
rep.outcome = 'skipped'
rep.wasxfail = xfailed.reason
elif call.when == "call":
elif call.when == 'call':
if xfailed.strict:
rep.outcome = "failed"
rep.longrepr = "[XPASS(strict)] " + xfailed.reason
rep.outcome = 'failed'
rep.longrepr = '[XPASS(strict)] ' + xfailed.reason
else:
rep.outcome = "passed"
rep.outcome = 'passed'
rep.wasxfail = xfailed.reason
return rep
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if hasattr(report, "wasxfail"):
def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if hasattr(report, 'wasxfail'):
if report.skipped:
return "xfailed", "x", "XFAIL"
return 'xfailed', 'x', 'XFAIL'
elif report.passed:
return "xpassed", "X", "XPASS"
return 'xpassed', 'X', 'XPASS'
return None

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any
from typing import cast
from typing import Dict
@ -6,11 +8,11 @@ from typing import TypeVar
from typing import Union
__all__ = ["Stash", "StashKey"]
__all__ = ['Stash', 'StashKey']
T = TypeVar("T")
D = TypeVar("D")
T = TypeVar('T')
D = TypeVar('D')
class StashKey(Generic[T]):
@ -63,10 +65,10 @@ class Stash:
some_bool = stash[some_bool_key]
"""
__slots__ = ("_storage",)
__slots__ = ('_storage',)
def __init__(self) -> None:
self._storage: Dict[StashKey[Any], object] = {}
self._storage: dict[StashKey[Any], object] = {}
def __setitem__(self, key: StashKey[T], value: T) -> None:
"""Set a value for key."""
@ -79,7 +81,7 @@ class Stash:
"""
return cast(T, self._storage[key])
def get(self, key: StashKey[T], default: D) -> Union[T, D]:
def get(self, key: StashKey[T], default: D) -> T | D:
"""Get the value for key, or return default if the key wasn't set
before."""
try:

View file

@ -1,39 +1,41 @@
from __future__ import annotations
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from _pytest import nodes
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from _pytest.reports import TestReport
import pytest
if TYPE_CHECKING:
from _pytest.cacheprovider import Cache
STEPWISE_CACHE_DIR = "cache/stepwise"
STEPWISE_CACHE_DIR = 'cache/stepwise'
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group = parser.getgroup('general')
group.addoption(
"--sw",
"--stepwise",
action="store_true",
'--sw',
'--stepwise',
action='store_true',
default=False,
dest="stepwise",
help="Exit on test failure and continue from last failing test next time",
dest='stepwise',
help='Exit on test failure and continue from last failing test next time',
)
group.addoption(
"--sw-skip",
"--stepwise-skip",
action="store_true",
'--sw-skip',
'--stepwise-skip',
action='store_true',
default=False,
dest="stepwise_skip",
help="Ignore the first failing test but stop on the next failing test. "
"Implicitly enables --stepwise.",
dest='stepwise_skip',
help='Ignore the first failing test but stop on the next failing test. '
'Implicitly enables --stepwise.',
)
@ -42,14 +44,14 @@ def pytest_configure(config: Config) -> None:
if config.option.stepwise_skip:
# allow --stepwise-skip to work on it's own merits.
config.option.stepwise = True
if config.getoption("stepwise"):
config.pluginmanager.register(StepwisePlugin(config), "stepwiseplugin")
if config.getoption('stepwise'):
config.pluginmanager.register(StepwisePlugin(config), 'stepwiseplugin')
def pytest_sessionfinish(session: Session) -> None:
if not session.config.getoption("stepwise"):
if not session.config.getoption('stepwise'):
assert session.config.cache is not None
if hasattr(session.config, "workerinput"):
if hasattr(session.config, 'workerinput'):
# Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641).
return
@ -60,21 +62,21 @@ def pytest_sessionfinish(session: Session) -> None:
class StepwisePlugin:
def __init__(self, config: Config) -> None:
self.config = config
self.session: Optional[Session] = None
self.report_status = ""
self.session: Session | None = None
self.report_status = ''
assert config.cache is not None
self.cache: Cache = config.cache
self.lastfailed: Optional[str] = self.cache.get(STEPWISE_CACHE_DIR, None)
self.skip: bool = config.getoption("stepwise_skip")
self.lastfailed: str | None = self.cache.get(STEPWISE_CACHE_DIR, None)
self.skip: bool = config.getoption('stepwise_skip')
def pytest_sessionstart(self, session: Session) -> None:
self.session = session
def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item]
self, config: Config, items: list[nodes.Item],
) -> None:
if not self.lastfailed:
self.report_status = "no previously failed tests, not skipping."
self.report_status = 'no previously failed tests, not skipping.'
return
# check all item nodes until we find a match on last failed
@ -87,9 +89,9 @@ class StepwisePlugin:
# If the previously failed test was not found among the test items,
# do not skip any tests.
if failed_index is None:
self.report_status = "previously failed test not found, not skipping."
self.report_status = 'previously failed test not found, not skipping.'
else:
self.report_status = f"skipping {failed_index} already passed items."
self.report_status = f'skipping {failed_index} already passed items.'
deselected = items[:failed_index]
del items[:failed_index]
config.hook.pytest_deselected(items=deselected)
@ -108,23 +110,23 @@ class StepwisePlugin:
self.lastfailed = report.nodeid
assert self.session is not None
self.session.shouldstop = (
"Test failed, continuing from this test next run."
'Test failed, continuing from this test next run.'
)
else:
# If the test was actually run and did pass.
if report.when == "call":
if report.when == 'call':
# Remove test from the failed ones, if exists.
if report.nodeid == self.lastfailed:
self.lastfailed = None
def pytest_report_collectionfinish(self) -> Optional[str]:
if self.config.getoption("verbose") >= 0 and self.report_status:
return f"stepwise: {self.report_status}"
def pytest_report_collectionfinish(self) -> str | None:
if self.config.getoption('verbose') >= 0 and self.report_status:
return f'stepwise: {self.report_status}'
return None
def pytest_sessionfinish(self) -> None:
if hasattr(self.config, "workerinput"):
if hasattr(self.config, 'workerinput'):
# Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641).
return

File diff suppressed because it is too large Load diff

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import threading
import traceback
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Generator
from typing import Optional
from typing import Type
import warnings
import pytest
@ -34,22 +36,22 @@ class catch_threading_exception:
"""
def __init__(self) -> None:
self.args: Optional["threading.ExceptHookArgs"] = None
self._old_hook: Optional[Callable[["threading.ExceptHookArgs"], Any]] = None
self.args: threading.ExceptHookArgs | None = None
self._old_hook: Callable[[threading.ExceptHookArgs], Any] | None = None
def _hook(self, args: "threading.ExceptHookArgs") -> None:
def _hook(self, args: threading.ExceptHookArgs) -> None:
self.args = args
def __enter__(self) -> "catch_threading_exception":
def __enter__(self) -> catch_threading_exception:
self._old_hook = threading.excepthook
threading.excepthook = self._hook
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
assert self._old_hook is not None
threading.excepthook = self._old_hook
@ -64,15 +66,15 @@ def thread_exception_runtest_hook() -> Generator[None, None, None]:
finally:
if cm.args:
thread_name = (
"<unknown>" if cm.args.thread is None else cm.args.thread.name
'<unknown>' if cm.args.thread is None else cm.args.thread.name
)
msg = f"Exception in thread {thread_name}\n\n"
msg += "".join(
msg = f'Exception in thread {thread_name}\n\n'
msg += ''.join(
traceback.format_exception(
cm.args.exc_type,
cm.args.exc_value,
cm.args.exc_traceback,
)
),
)
warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))

View file

@ -5,10 +5,11 @@ pytest runtime information (issue #185).
Fixture "mock_timing" also interacts with this module for pytest's own tests.
"""
from __future__ import annotations
from time import perf_counter
from time import sleep
from time import time
__all__ = ["perf_counter", "sleep", "time"]
__all__ = ['perf_counter', 'sleep', 'time']

View file

@ -1,11 +1,13 @@
# mypy: allow-untyped-defs
"""Support for providing temporary directories to test functions."""
from __future__ import annotations
import dataclasses
import os
from pathlib import Path
import re
from shutil import rmtree
import tempfile
from pathlib import Path
from shutil import rmtree
from typing import Any
from typing import Dict
from typing import final
@ -14,11 +16,6 @@ from typing import Literal
from typing import Optional
from typing import Union
from .pathlib import cleanup_dead_symlinks
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
from _pytest.compat import get_user_id
from _pytest.config import Config
from _pytest.config import ExitCode
@ -32,9 +29,15 @@ from _pytest.nodes import Item
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from .pathlib import cleanup_dead_symlinks
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
tmppath_result_key = StashKey[Dict[str, bool]]()
RetentionType = Literal["all", "failed", "none"]
RetentionType = Literal['all', 'failed', 'none']
@final
@ -45,20 +48,20 @@ class TempPathFactory:
The base directory can be configured using the ``--basetemp`` option.
"""
_given_basetemp: Optional[Path]
_given_basetemp: Path | None
# pluggy TagTracerSub, not currently exposed, so Any.
_trace: Any
_basetemp: Optional[Path]
_basetemp: Path | None
_retention_count: int
_retention_policy: RetentionType
def __init__(
self,
given_basetemp: Optional[Path],
given_basetemp: Path | None,
retention_count: int,
retention_policy: RetentionType,
trace,
basetemp: Optional[Path] = None,
basetemp: Path | None = None,
*,
_ispytest: bool = False,
) -> None:
@ -81,27 +84,27 @@ class TempPathFactory:
config: Config,
*,
_ispytest: bool = False,
) -> "TempPathFactory":
) -> TempPathFactory:
"""Create a factory according to pytest configuration.
:meta private:
"""
check_ispytest(_ispytest)
count = int(config.getini("tmp_path_retention_count"))
count = int(config.getini('tmp_path_retention_count'))
if count < 0:
raise ValueError(
f"tmp_path_retention_count must be >= 0. Current input: {count}."
f'tmp_path_retention_count must be >= 0. Current input: {count}.',
)
policy = config.getini("tmp_path_retention_policy")
if policy not in ("all", "failed", "none"):
policy = config.getini('tmp_path_retention_policy')
if policy not in ('all', 'failed', 'none'):
raise ValueError(
f"tmp_path_retention_policy must be either all, failed, none. Current input: {policy}."
f'tmp_path_retention_policy must be either all, failed, none. Current input: {policy}.',
)
return cls(
given_basetemp=config.option.basetemp,
trace=config.trace.get("tmpdir"),
trace=config.trace.get('tmpdir'),
retention_count=count,
retention_policy=policy,
_ispytest=True,
@ -110,7 +113,7 @@ class TempPathFactory:
def _ensure_relative_to_basetemp(self, basename: str) -> str:
basename = os.path.normpath(basename)
if (self.getbasetemp() / basename).resolve().parent != self.getbasetemp():
raise ValueError(f"{basename} is not a normalized and relative path")
raise ValueError(f'{basename} is not a normalized and relative path')
return basename
def mktemp(self, basename: str, numbered: bool = True) -> Path:
@ -134,7 +137,7 @@ class TempPathFactory:
p.mkdir(mode=0o700)
else:
p = make_numbered_dir(root=self.getbasetemp(), prefix=basename, mode=0o700)
self._trace("mktemp", p)
self._trace('mktemp', p)
return p
def getbasetemp(self) -> Path:
@ -153,17 +156,17 @@ class TempPathFactory:
basetemp.mkdir(mode=0o700)
basetemp = basetemp.resolve()
else:
from_env = os.environ.get("PYTEST_DEBUG_TEMPROOT")
from_env = os.environ.get('PYTEST_DEBUG_TEMPROOT')
temproot = Path(from_env or tempfile.gettempdir()).resolve()
user = get_user() or "unknown"
user = get_user() or 'unknown'
# use a sub-directory in the temproot to speed-up
# make_numbered_dir() call
rootdir = temproot.joinpath(f"pytest-of-{user}")
rootdir = temproot.joinpath(f'pytest-of-{user}')
try:
rootdir.mkdir(mode=0o700, exist_ok=True)
except OSError:
# getuser() likely returned illegal characters for the platform, use unknown back off mechanism
rootdir = temproot.joinpath("pytest-of-unknown")
rootdir = temproot.joinpath('pytest-of-unknown')
rootdir.mkdir(mode=0o700, exist_ok=True)
# Because we use exist_ok=True with a predictable name, make sure
# we are the owners, to prevent any funny business (on unix, where
@ -176,16 +179,16 @@ class TempPathFactory:
rootdir_stat = rootdir.stat()
if rootdir_stat.st_uid != uid:
raise OSError(
f"The temporary directory {rootdir} is not owned by the current user. "
"Fix this and try again."
f'The temporary directory {rootdir} is not owned by the current user. '
'Fix this and try again.',
)
if (rootdir_stat.st_mode & 0o077) != 0:
os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
keep = self._retention_count
if self._retention_policy == "none":
if self._retention_policy == 'none':
keep = 0
basetemp = make_numbered_dir_with_cleanup(
prefix="pytest-",
prefix='pytest-',
root=rootdir,
keep=keep,
lock_timeout=LOCK_TIMEOUT,
@ -193,11 +196,11 @@ class TempPathFactory:
)
assert basetemp is not None, basetemp
self._basetemp = basetemp
self._trace("new basetemp", basetemp)
self._trace('new basetemp', basetemp)
return basetemp
def get_user() -> Optional[str]:
def get_user() -> str | None:
"""Return the current user name, or None if getuser() does not work
in the current environment (see #1010)."""
try:
@ -219,25 +222,25 @@ def pytest_configure(config: Config) -> None:
mp = MonkeyPatch()
config.add_cleanup(mp.undo)
_tmp_path_factory = TempPathFactory.from_config(config, _ispytest=True)
mp.setattr(config, "_tmp_path_factory", _tmp_path_factory, raising=False)
mp.setattr(config, '_tmp_path_factory', _tmp_path_factory, raising=False)
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"tmp_path_retention_count",
help="How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`.",
'tmp_path_retention_count',
help='How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`.',
default=3,
)
parser.addini(
"tmp_path_retention_policy",
help="Controls which directories created by the `tmp_path` fixture are kept around, based on test outcome. "
"(all/failed/none)",
default="all",
'tmp_path_retention_policy',
help='Controls which directories created by the `tmp_path` fixture are kept around, based on test outcome. '
'(all/failed/none)',
default='all',
)
@fixture(scope="session")
@fixture(scope='session')
def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
"""Return a :class:`pytest.TempPathFactory` instance for the test session."""
# Set dynamically by pytest_configure() above.
@ -246,7 +249,7 @@ def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
name = request.node.name
name = re.sub(r"[\W]", "_", name)
name = re.sub(r'[\W]', '_', name)
MAXVAL = 30
name = name[:MAXVAL]
return factory.mktemp(name, numbered=True)
@ -254,7 +257,7 @@ def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
@fixture
def tmp_path(
request: FixtureRequest, tmp_path_factory: TempPathFactory
request: FixtureRequest, tmp_path_factory: TempPathFactory,
) -> Generator[Path, None, None]:
"""Return a temporary directory path object which is unique to each test
function invocation, created as a sub directory of the base temporary
@ -277,7 +280,7 @@ def tmp_path(
policy = tmp_path_factory._retention_policy
result_dict = request.node.stash[tmppath_result_key]
if policy == "failed" and result_dict.get("call", True):
if policy == 'failed' and result_dict.get('call', True):
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
# permissions, etc, in which case we ignore it.
rmtree(path, ignore_errors=True)
@ -285,7 +288,7 @@ def tmp_path(
del request.node.stash[tmppath_result_key]
def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]):
def pytest_sessionfinish(session, exitstatus: int | ExitCode):
"""After each session, remove base directory if all the tests passed,
the policy is "failed", and the basetemp is not specified by a user.
"""
@ -296,9 +299,9 @@ def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]):
policy = tmp_path_factory._retention_policy
if (
exitstatus == 0
and policy == "failed"
and tmp_path_factory._given_basetemp is None
exitstatus == 0 and
policy == 'failed' and
tmp_path_factory._given_basetemp is None
):
if basetemp.is_dir():
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
@ -312,10 +315,10 @@ def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]):
@hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_makereport(
item: Item, call
item: Item, call,
) -> Generator[None, TestReport, TestReport]:
rep = yield
assert rep.when is not None
empty: Dict[str, bool] = {}
empty: dict[str, bool] = {}
item.stash.setdefault(tmppath_result_key, empty)[rep.when] = rep.passed
return rep

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
"""Discover and run std-library "unittest" style tests."""
from __future__ import annotations
import sys
import traceback
import types
@ -15,6 +17,7 @@ from typing import TYPE_CHECKING
from typing import Union
import _pytest._code
import pytest
from _pytest.compat import getimfunc
from _pytest.compat import is_async_function
from _pytest.config import hookimpl
@ -29,7 +32,6 @@ from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import Module
from _pytest.runner import CallInfo
import pytest
if TYPE_CHECKING:
@ -44,11 +46,11 @@ if TYPE_CHECKING:
def pytest_pycollect_makeitem(
collector: Union[Module, Class], name: str, obj: object
) -> Optional["UnitTestCase"]:
collector: Module | Class, name: str, obj: object,
) -> UnitTestCase | None:
# Has unittest been imported and is obj a subclass of its TestCase?
try:
ut = sys.modules["unittest"]
ut = sys.modules['unittest']
# Type ignored because `ut` is an opaque module.
if not issubclass(obj, ut.TestCase): # type: ignore
return None
@ -63,11 +65,11 @@ class UnitTestCase(Class):
# to declare that our children do not support funcargs.
nofuncargs = True
def collect(self) -> Iterable[Union[Item, Collector]]:
def collect(self) -> Iterable[Item | Collector]:
from unittest import TestLoader
cls = self.obj
if not getattr(cls, "__test__", True):
if not getattr(cls, '__test__', True):
return
skipped = _is_skipped(cls)
@ -81,28 +83,28 @@ class UnitTestCase(Class):
foundsomething = False
for name in loader.getTestCaseNames(self.obj):
x = getattr(self.obj, name)
if not getattr(x, "__test__", True):
if not getattr(x, '__test__', True):
continue
funcobj = getimfunc(x)
yield TestCaseFunction.from_parent(self, name=name, callobj=funcobj)
foundsomething = True
if not foundsomething:
runtest = getattr(self.obj, "runTest", None)
runtest = getattr(self.obj, 'runTest', None)
if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None)
ut = sys.modules.get('twisted.trial.unittest', None)
# Type ignored because `ut` is an opaque module.
if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest")
yield TestCaseFunction.from_parent(self, name='runTest')
def _register_unittest_setup_class_fixture(self, cls: type) -> None:
"""Register an auto-use fixture to invoke setUpClass and
tearDownClass (#517)."""
setup = getattr(cls, "setUpClass", None)
teardown = getattr(cls, "tearDownClass", None)
setup = getattr(cls, 'setUpClass', None)
teardown = getattr(cls, 'tearDownClass', None)
if setup is None and teardown is None:
return None
cleanup = getattr(cls, "doClassCleanups", lambda: None)
cleanup = getattr(cls, 'doClassCleanups', lambda: None)
def unittest_setup_class_fixture(
request: FixtureRequest,
@ -128,18 +130,18 @@ class UnitTestCase(Class):
self.session._fixturemanager._register_fixture(
# Use a unique name to speed up lookup.
name=f"_unittest_setUpClass_fixture_{cls.__qualname__}",
name=f'_unittest_setUpClass_fixture_{cls.__qualname__}',
func=unittest_setup_class_fixture,
nodeid=self.nodeid,
scope="class",
scope='class',
autouse=True,
)
def _register_unittest_setup_method_fixture(self, cls: type) -> None:
"""Register an auto-use fixture to invoke setup_method and
teardown_method (#517)."""
setup = getattr(cls, "setup_method", None)
teardown = getattr(cls, "teardown_method", None)
setup = getattr(cls, 'setup_method', None)
teardown = getattr(cls, 'teardown_method', None)
if setup is None and teardown is None:
return None
@ -158,18 +160,18 @@ class UnitTestCase(Class):
self.session._fixturemanager._register_fixture(
# Use a unique name to speed up lookup.
name=f"_unittest_setup_method_fixture_{cls.__qualname__}",
name=f'_unittest_setup_method_fixture_{cls.__qualname__}',
func=unittest_setup_method_fixture,
nodeid=self.nodeid,
scope="function",
scope='function',
autouse=True,
)
class TestCaseFunction(Function):
nofuncargs = True
_excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None
_testcase: Optional["unittest.TestCase"] = None
_excinfo: list[_pytest._code.ExceptionInfo[BaseException]] | None = None
_testcase: unittest.TestCase | None = None
def _getobj(self):
assert self.parent is not None
@ -182,7 +184,7 @@ class TestCaseFunction(Function):
def setup(self) -> None:
# A bound method to be called during teardown() if set (see 'runtest()').
self._explicit_tearDown: Optional[Callable[[], None]] = None
self._explicit_tearDown: Callable[[], None] | None = None
assert self.parent is not None
self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined]
self._obj = getattr(self._testcase, self.name)
@ -196,15 +198,15 @@ class TestCaseFunction(Function):
self._testcase = None
self._obj = None
def startTest(self, testcase: "unittest.TestCase") -> None:
def startTest(self, testcase: unittest.TestCase) -> None:
pass
def _addexcinfo(self, rawexcinfo: "_SysExcInfoType") -> None:
def _addexcinfo(self, rawexcinfo: _SysExcInfoType) -> None:
# Unwrap potential exception info (see twisted trial support below).
rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo)
rawexcinfo = getattr(rawexcinfo, '_rawexcinfo', rawexcinfo)
try:
excinfo = _pytest._code.ExceptionInfo[BaseException].from_exc_info(
rawexcinfo # type: ignore[arg-type]
rawexcinfo, # type: ignore[arg-type]
)
# Invoke the attributes to trigger storing the traceback
# trial causes some issue there.
@ -216,26 +218,26 @@ class TestCaseFunction(Function):
values = traceback.format_exception(*rawexcinfo)
values.insert(
0,
"NOTE: Incompatible Exception Representation, "
"displaying natively:\n\n",
'NOTE: Incompatible Exception Representation, '
'displaying natively:\n\n',
)
fail("".join(values), pytrace=False)
fail(''.join(values), pytrace=False)
except (fail.Exception, KeyboardInterrupt):
raise
except BaseException:
fail(
"ERROR: Unknown Incompatible Exception "
f"representation:\n{rawexcinfo!r}",
'ERROR: Unknown Incompatible Exception '
f'representation:\n{rawexcinfo!r}',
pytrace=False,
)
except KeyboardInterrupt:
raise
except fail.Exception:
excinfo = _pytest._code.ExceptionInfo.from_current()
self.__dict__.setdefault("_excinfo", []).append(excinfo)
self.__dict__.setdefault('_excinfo', []).append(excinfo)
def addError(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType,
) -> None:
try:
if isinstance(rawexcinfo[1], exit.Exception):
@ -245,11 +247,11 @@ class TestCaseFunction(Function):
self._addexcinfo(rawexcinfo)
def addFailure(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType,
) -> None:
self._addexcinfo(rawexcinfo)
def addSkip(self, testcase: "unittest.TestCase", reason: str) -> None:
def addSkip(self, testcase: unittest.TestCase, reason: str) -> None:
try:
raise pytest.skip.Exception(reason, _use_item_location=True)
except skip.Exception:
@ -257,9 +259,9 @@ class TestCaseFunction(Function):
def addExpectedFailure(
self,
testcase: "unittest.TestCase",
rawexcinfo: "_SysExcInfoType",
reason: str = "",
testcase: unittest.TestCase,
rawexcinfo: _SysExcInfoType,
reason: str = '',
) -> None:
try:
xfail(str(reason))
@ -268,25 +270,25 @@ class TestCaseFunction(Function):
def addUnexpectedSuccess(
self,
testcase: "unittest.TestCase",
reason: Optional["twisted.trial.unittest.Todo"] = None,
testcase: unittest.TestCase,
reason: twisted.trial.unittest.Todo | None = None,
) -> None:
msg = "Unexpected success"
msg = 'Unexpected success'
if reason:
msg += f": {reason.reason}"
msg += f': {reason.reason}'
# Preserve unittest behaviour - fail the test. Explicitly not an XPASS.
try:
fail(msg, pytrace=False)
except fail.Exception:
self._addexcinfo(sys.exc_info())
def addSuccess(self, testcase: "unittest.TestCase") -> None:
def addSuccess(self, testcase: unittest.TestCase) -> None:
pass
def stopTest(self, testcase: "unittest.TestCase") -> None:
def stopTest(self, testcase: unittest.TestCase) -> None:
pass
def addDuration(self, testcase: "unittest.TestCase", elapsed: float) -> None:
def addDuration(self, testcase: unittest.TestCase, elapsed: float) -> None:
pass
def runtest(self) -> None:
@ -310,9 +312,9 @@ class TestCaseFunction(Function):
# We need to consider if the test itself is skipped, or the whole class.
assert isinstance(self.parent, UnitTestCase)
skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj)
if self.config.getoption("usepdb") and not skipped:
if self.config.getoption('usepdb') and not skipped:
self._explicit_tearDown = self._testcase.tearDown
setattr(self._testcase, "tearDown", lambda *args: None)
setattr(self._testcase, 'tearDown', lambda *args: None)
# We need to update the actual bound method with self.obj, because
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper.
@ -323,11 +325,11 @@ class TestCaseFunction(Function):
delattr(self._testcase, self.name)
def _traceback_filter(
self, excinfo: _pytest._code.ExceptionInfo[BaseException]
self, excinfo: _pytest._code.ExceptionInfo[BaseException],
) -> _pytest._code.Traceback:
traceback = super()._traceback_filter(excinfo)
ntraceback = traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest"),
lambda x: not x.frame.f_globals.get('__unittest'),
)
if not ntraceback:
ntraceback = traceback
@ -348,13 +350,13 @@ def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
# This is actually only needed for nose, which reuses unittest.SkipTest for
# its own nose.SkipTest. For unittest TestCases, SkipTest is already
# handled internally, and doesn't reach here.
unittest = sys.modules.get("unittest")
unittest = sys.modules.get('unittest')
if (
unittest and call.excinfo and isinstance(call.excinfo.value, unittest.SkipTest) # type: ignore[attr-defined]
):
excinfo = call.excinfo
call2 = CallInfo[None].from_call(
lambda: pytest.skip(str(excinfo.value)), call.when
lambda: pytest.skip(str(excinfo.value)), call.when,
)
call.excinfo = call2.excinfo
@ -365,8 +367,8 @@ classImplements_has_run = False
@hookimpl(wrapper=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
if isinstance(item, TestCaseFunction) and "twisted.trial.unittest" in sys.modules:
ut: Any = sys.modules["twisted.python.failure"]
if isinstance(item, TestCaseFunction) and 'twisted.trial.unittest' in sys.modules:
ut: Any = sys.modules['twisted.python.failure']
global classImplements_has_run
Failure__init__ = ut.Failure.__init__
if not classImplements_has_run:
@ -377,7 +379,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
classImplements_has_run = True
def excstore(
self, exc_value=None, exc_type=None, exc_tb=None, captureVars=None
self, exc_value=None, exc_type=None, exc_tb=None, captureVars=None,
):
if exc_value is None:
self._rawexcinfo = sys.exc_info()
@ -387,7 +389,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
self._rawexcinfo = (exc_type, exc_value, exc_tb)
try:
Failure__init__(
self, exc_value, exc_type, exc_tb, captureVars=captureVars
self, exc_value, exc_type, exc_tb, captureVars=captureVars,
)
except TypeError:
Failure__init__(self, exc_value, exc_type, exc_tb)
@ -404,4 +406,4 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
def _is_skipped(obj) -> bool:
"""Return True if the given object has been marked with @unittest.skip."""
return bool(getattr(obj, "__unittest_skip__", False))
return bool(getattr(obj, '__unittest_skip__', False))

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import sys
import traceback
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Generator
from typing import Optional
from typing import Type
import warnings
import pytest
@ -34,24 +36,24 @@ class catch_unraisable_exception:
"""
def __init__(self) -> None:
self.unraisable: Optional["sys.UnraisableHookArgs"] = None
self._old_hook: Optional[Callable[["sys.UnraisableHookArgs"], Any]] = None
self.unraisable: sys.UnraisableHookArgs | None = None
self._old_hook: Callable[[sys.UnraisableHookArgs], Any] | None = None
def _hook(self, unraisable: "sys.UnraisableHookArgs") -> None:
def _hook(self, unraisable: sys.UnraisableHookArgs) -> None:
# Storing unraisable.object can resurrect an object which is being
# finalized. Storing unraisable.exc_value creates a reference cycle.
self.unraisable = unraisable
def __enter__(self) -> "catch_unraisable_exception":
def __enter__(self) -> catch_unraisable_exception:
self._old_hook = sys.unraisablehook
sys.unraisablehook = self._hook
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
assert self._old_hook is not None
sys.unraisablehook = self._old_hook
@ -68,14 +70,14 @@ def unraisable_exception_runtest_hook() -> Generator[None, None, None]:
if cm.unraisable.err_msg is not None:
err_msg = cm.unraisable.err_msg
else:
err_msg = "Exception ignored in"
msg = f"{err_msg}: {cm.unraisable.object!r}\n\n"
msg += "".join(
err_msg = 'Exception ignored in'
msg = f'{err_msg}: {cm.unraisable.object!r}\n\n'
msg += ''.join(
traceback.format_exception(
cm.unraisable.exc_type,
cm.unraisable.exc_value,
cm.unraisable.exc_traceback,
)
),
)
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))

View file

@ -1,64 +1,66 @@
from __future__ import annotations
import dataclasses
import inspect
import warnings
from types import FunctionType
from typing import Any
from typing import final
from typing import Generic
from typing import Type
from typing import TypeVar
import warnings
class PytestWarning(UserWarning):
"""Base class for all warnings emitted by pytest."""
__module__ = "pytest"
__module__ = 'pytest'
@final
class PytestAssertRewriteWarning(PytestWarning):
"""Warning emitted by the pytest assert rewrite module."""
__module__ = "pytest"
__module__ = 'pytest'
@final
class PytestCacheWarning(PytestWarning):
"""Warning emitted by the cache plugin in various situations."""
__module__ = "pytest"
__module__ = 'pytest'
@final
class PytestConfigWarning(PytestWarning):
"""Warning emitted for configuration issues."""
__module__ = "pytest"
__module__ = 'pytest'
@final
class PytestCollectionWarning(PytestWarning):
"""Warning emitted when pytest is not able to collect a file or symbol in a module."""
__module__ = "pytest"
__module__ = 'pytest'
class PytestDeprecationWarning(PytestWarning, DeprecationWarning):
"""Warning class for features that will be removed in a future version."""
__module__ = "pytest"
__module__ = 'pytest'
class PytestRemovedIn9Warning(PytestDeprecationWarning):
"""Warning class for features that will be removed in pytest 9."""
__module__ = "pytest"
__module__ = 'pytest'
class PytestReturnNotNoneWarning(PytestWarning):
"""Warning emitted when a test function is returning value other than None."""
__module__ = "pytest"
__module__ = 'pytest'
@final
@ -69,11 +71,11 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
future version.
"""
__module__ = "pytest"
__module__ = 'pytest'
@classmethod
def simple(cls, apiname: str) -> "PytestExperimentalApiWarning":
return cls(f"{apiname} is an experimental api that may change over time")
def simple(cls, apiname: str) -> PytestExperimentalApiWarning:
return cls(f'{apiname} is an experimental api that may change over time')
@final
@ -85,7 +87,7 @@ class PytestUnhandledCoroutineWarning(PytestReturnNotNoneWarning):
Coroutine test functions are not natively supported.
"""
__module__ = "pytest"
__module__ = 'pytest'
@final
@ -95,7 +97,7 @@ class PytestUnknownMarkWarning(PytestWarning):
See :ref:`mark` for details.
"""
__module__ = "pytest"
__module__ = 'pytest'
@final
@ -107,7 +109,7 @@ class PytestUnraisableExceptionWarning(PytestWarning):
as normal.
"""
__module__ = "pytest"
__module__ = 'pytest'
@final
@ -117,10 +119,10 @@ class PytestUnhandledThreadExceptionWarning(PytestWarning):
Such exceptions don't propagate normally.
"""
__module__ = "pytest"
__module__ = 'pytest'
_W = TypeVar("_W", bound=PytestWarning)
_W = TypeVar('_W', bound=PytestWarning)
@final
@ -132,7 +134,7 @@ class UnformattedWarning(Generic[_W]):
as opposed to a direct message.
"""
category: Type["_W"]
category: type[_W]
template: str
def format(self, **kwargs: Any) -> _W:
@ -157,9 +159,9 @@ def warn_explicit_for(method: FunctionType, message: PytestWarning) -> None:
type(message),
filename=filename,
module=module,
registry=mod_globals.setdefault("__warningregistry__", {}),
registry=mod_globals.setdefault('__warningregistry__', {}),
lineno=lineno,
)
except Warning as w:
# If warnings are errors (e.g. -Werror), location information gets lost, so we add it to the message.
raise type(w)(f"{w}\n at {filename}:{lineno}") from None
raise type(w)(f'{w}\n at {filename}:{lineno}') from None

View file

@ -1,25 +1,27 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from __future__ import annotations
import sys
import warnings
from contextlib import contextmanager
from typing import Generator
from typing import Literal
from typing import Optional
import warnings
import pytest
from _pytest.config import apply_warning_filters
from _pytest.config import Config
from _pytest.config import parse_warning_filter
from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.terminal import TerminalReporter
import pytest
def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"filterwarnings(warning): add a warning filter to the given test. "
"see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#pytest-mark-filterwarnings ",
'markers',
'filterwarnings(warning): add a warning filter to the given test. '
'see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#pytest-mark-filterwarnings ',
)
@ -27,8 +29,8 @@ def pytest_configure(config: Config) -> None:
def catch_warnings_for_item(
config: Config,
ihook,
when: Literal["config", "collect", "runtest"],
item: Optional[Item],
when: Literal['config', 'collect', 'runtest'],
item: Item | None,
) -> Generator[None, None, None]:
"""Context manager that catches warnings generated in the contained execution block.
@ -36,7 +38,7 @@ def catch_warnings_for_item(
Each warning captured triggers the ``pytest_warning_recorded`` hook.
"""
config_filters = config.getini("filterwarnings")
config_filters = config.getini('filterwarnings')
cmdline_filters = config.known_args_namespace.pythonwarnings or []
with warnings.catch_warnings(record=True) as log:
# mypy can't infer that record=True means log is not None; help it.
@ -44,8 +46,8 @@ def catch_warnings_for_item(
if not sys.warnoptions:
# If user is not explicitly configuring warning filters, show deprecation warnings by default (#2908).
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.filterwarnings("always", category=PendingDeprecationWarning)
warnings.filterwarnings('always', category=DeprecationWarning)
warnings.filterwarnings('always', category=PendingDeprecationWarning)
# To be enabled in pytest 9.0.0.
# warnings.filterwarnings("error", category=pytest.PytestRemovedIn9Warning)
@ -53,9 +55,9 @@ def catch_warnings_for_item(
apply_warning_filters(config_filters, cmdline_filters)
# apply filters from "filterwarnings" marks
nodeid = "" if item is None else item.nodeid
nodeid = '' if item is None else item.nodeid
if item is not None:
for mark in item.iter_markers(name="filterwarnings"):
for mark in item.iter_markers(name='filterwarnings'):
for arg in mark.args:
warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
@ -69,7 +71,7 @@ def catch_warnings_for_item(
nodeid=nodeid,
when=when,
location=None,
)
),
)
@ -91,22 +93,22 @@ def warning_record_to_str(warning_message: warnings.WarningMessage) -> str:
else:
tb = tracemalloc.get_object_traceback(warning_message.source)
if tb is not None:
formatted_tb = "\n".join(tb.format())
formatted_tb = '\n'.join(tb.format())
# Use a leading new line to better separate the (large) output
# from the traceback to the previous warning text.
msg += f"\nObject allocated at:\n{formatted_tb}"
msg += f'\nObject allocated at:\n{formatted_tb}'
else:
# No need for a leading new line.
url = "https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings"
msg += "Enable tracemalloc to get traceback where the object was allocated.\n"
msg += f"See {url} for more info."
url = 'https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings'
msg += 'Enable tracemalloc to get traceback where the object was allocated.\n'
msg += f'See {url} for more info.'
return msg
@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
with catch_warnings_for_item(
config=item.config, ihook=item.ihook, when="runtest", item=item
config=item.config, ihook=item.ihook, when='runtest', item=item,
):
return (yield)
@ -115,7 +117,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
def pytest_collection(session: Session) -> Generator[None, object, object]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="collect", item=None
config=config, ihook=config.hook, when='collect', item=None,
):
return (yield)
@ -126,7 +128,7 @@ def pytest_terminal_summary(
) -> Generator[None, None, None]:
config = terminalreporter.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
config=config, ihook=config.hook, when='config', item=None,
):
return (yield)
@ -135,16 +137,16 @@ def pytest_terminal_summary(
def pytest_sessionfinish(session: Session) -> Generator[None, None, None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
config=config, ihook=config.hook, when='config', item=None,
):
return (yield)
@pytest.hookimpl(wrapper=True)
def pytest_load_initial_conftests(
early_config: "Config",
early_config: Config,
) -> Generator[None, None, None]:
with catch_warnings_for_item(
config=early_config, ihook=early_config.hook, when="config", item=None
config=early_config, ihook=early_config.hook, when='config', item=None,
):
return (yield)

View file

@ -3,4 +3,3 @@ Generator: bdist_wheel (0.37.1)
Root-Is-Purelib: true
Tag: py2-none-any
Tag: py3-none-any

View file

@ -2,4 +2,3 @@ Wheel-Version: 1.0
Generator: bdist_wheel (0.43.0)
Root-Is-Purelib: false
Tag: cp310-cp310-macosx_10_9_x86_64

View file

@ -1,6 +1,5 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""
Code coverage measurement for Python.
@ -8,31 +7,22 @@ Ned Batchelder
https://coverage.readthedocs.io
"""
from __future__ import annotations
from coverage.control import Coverage as Coverage
from coverage.control import process_startup as process_startup
from coverage.data import CoverageData as CoverageData
from coverage.exceptions import CoverageException as CoverageException
from coverage.plugin import CoveragePlugin as CoveragePlugin
from coverage.plugin import FileReporter as FileReporter
from coverage.plugin import FileTracer as FileTracer
from coverage.version import __version__ as __version__
from coverage.version import version_info as version_info
# mypy's convention is that "import as" names are public from the module.
# We import names as themselves to indicate that. Pylint sees it as pointless,
# so disable its warning.
# pylint: disable=useless-import-alias
from coverage.version import (
__version__ as __version__,
version_info as version_info,
)
from coverage.control import (
Coverage as Coverage,
process_startup as process_startup,
)
from coverage.data import CoverageData as CoverageData
from coverage.exceptions import CoverageException as CoverageException
from coverage.plugin import (
CoveragePlugin as CoveragePlugin,
FileReporter as FileReporter,
FileTracer as FileTracer,
)
# Backward compatibility.
coverage = Coverage

View file

@ -1,10 +1,9 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Coverage.py's main entry point."""
from __future__ import annotations
import sys
from coverage.cmdline import main
sys.exit(main())

View file

@ -1,17 +1,16 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Source file annotation for coverage.py."""
from __future__ import annotations
import os
import re
from typing import Iterable, TYPE_CHECKING
from typing import Iterable
from typing import TYPE_CHECKING
from coverage.files import flat_rootname
from coverage.misc import ensure_dir, isolate_module
from coverage.misc import ensure_dir
from coverage.misc import isolate_module
from coverage.plugin import FileReporter
from coverage.report_core import get_analysis_to_report
from coverage.results import Analysis
@ -50,8 +49,8 @@ class AnnotateReporter:
self.config = self.coverage.config
self.directory: str | None = None
blank_re = re.compile(r"\s*(#|$)")
else_re = re.compile(r"\s*else\s*:\s*(#|$)")
blank_re = re.compile(r'\s*(#|$)')
else_re = re.compile(r'\s*else\s*:\s*(#|$)')
def report(self, morfs: Iterable[TMorf] | None, directory: str | None = None) -> None:
"""Run the report.
@ -77,13 +76,13 @@ class AnnotateReporter:
if self.directory:
ensure_dir(self.directory)
dest_file = os.path.join(self.directory, flat_rootname(fr.relative_filename()))
if dest_file.endswith("_py"):
dest_file = dest_file[:-3] + ".py"
dest_file += ",cover"
if dest_file.endswith('_py'):
dest_file = dest_file[:-3] + '.py'
dest_file += ',cover'
else:
dest_file = fr.filename + ",cover"
dest_file = fr.filename + ',cover'
with open(dest_file, "w", encoding="utf-8") as dest:
with open(dest_file, 'w', encoding='utf-8') as dest:
i = j = 0
covered = True
source = fr.source()
@ -95,20 +94,20 @@ class AnnotateReporter:
if i < len(statements) and statements[i] == lineno:
covered = j >= len(missing) or missing[j] > lineno
if self.blank_re.match(line):
dest.write(" ")
dest.write(' ')
elif self.else_re.match(line):
# Special logic for lines containing only "else:".
if j >= len(missing):
dest.write("> ")
dest.write('> ')
elif statements[i] == missing[j]:
dest.write("! ")
dest.write('! ')
else:
dest.write("> ")
dest.write('> ')
elif lineno in excluded:
dest.write("- ")
dest.write('- ')
elif covered:
dest.write("> ")
dest.write('> ')
else:
dest.write("! ")
dest.write('! ')
dest.write(line)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Bytecode manipulation for coverage.py"""
from __future__ import annotations
from types import CodeType

View file

@ -1,20 +1,18 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Command-line support for coverage.py."""
from __future__ import annotations
import glob
import optparse # pylint: disable=deprecated-module
import os
import os.path
import shlex
import sys
import textwrap
import traceback
from typing import cast, Any, NoReturn
from typing import Any
from typing import cast
from typing import NoReturn
import coverage
from coverage import Coverage
@ -22,16 +20,23 @@ from coverage import env
from coverage.collector import HAS_CTRACER
from coverage.config import CoverageConfig
from coverage.control import DEFAULT_DATAFILE
from coverage.data import combinable_files, debug_data_file
from coverage.debug import info_header, short_stack, write_formatted_info
from coverage.exceptions import _BaseCoverageException, _ExceptionDuringRun, NoSource
from coverage.data import combinable_files
from coverage.data import debug_data_file
from coverage.debug import info_header
from coverage.debug import short_stack
from coverage.debug import write_formatted_info
from coverage.exceptions import _BaseCoverageException
from coverage.exceptions import _ExceptionDuringRun
from coverage.exceptions import NoSource
from coverage.execfile import PyRunner
from coverage.results import Numbers, should_fail_under
from coverage.results import Numbers
from coverage.results import should_fail_under
from coverage.version import __url__
# When adding to this file, alphabetization is important. Look for
# "alphabetize" comments throughout.
class Opts:
"""A namespace class for individual options we'll build parsers from."""
@ -39,193 +44,193 @@ class Opts:
# appears on the command line.
append = optparse.make_option(
"-a", "--append", action="store_true",
help="Append coverage data to .coverage, otherwise it starts clean each time.",
'-a', '--append', action='store_true',
help='Append coverage data to .coverage, otherwise it starts clean each time.',
)
branch = optparse.make_option(
"", "--branch", action="store_true",
help="Measure branch coverage in addition to statement coverage.",
'', '--branch', action='store_true',
help='Measure branch coverage in addition to statement coverage.',
)
concurrency = optparse.make_option(
"", "--concurrency", action="store", metavar="LIBS",
'', '--concurrency', action='store', metavar='LIBS',
help=(
"Properly measure code using a concurrency library. " +
"Valid values are: {}, or a comma-list of them."
).format(", ".join(sorted(CoverageConfig.CONCURRENCY_CHOICES))),
'Properly measure code using a concurrency library. ' +
'Valid values are: {}, or a comma-list of them.'
).format(', '.join(sorted(CoverageConfig.CONCURRENCY_CHOICES))),
)
context = optparse.make_option(
"", "--context", action="store", metavar="LABEL",
help="The context label to record for this coverage run.",
'', '--context', action='store', metavar='LABEL',
help='The context label to record for this coverage run.',
)
contexts = optparse.make_option(
"", "--contexts", action="store", metavar="REGEX1,REGEX2,...",
'', '--contexts', action='store', metavar='REGEX1,REGEX2,...',
help=(
"Only display data from lines covered in the given contexts. " +
"Accepts Python regexes, which must be quoted."
'Only display data from lines covered in the given contexts. ' +
'Accepts Python regexes, which must be quoted.'
),
)
datafile = optparse.make_option(
"", "--data-file", action="store", metavar="DATAFILE",
'', '--data-file', action='store', metavar='DATAFILE',
help=(
"Base name of the data files to operate on. " +
'Base name of the data files to operate on. ' +
"Defaults to '.coverage'. [env: COVERAGE_FILE]"
),
)
datafle_input = optparse.make_option(
"", "--data-file", action="store", metavar="INFILE",
'', '--data-file', action='store', metavar='INFILE',
help=(
"Read coverage data for report generation from this file. " +
'Read coverage data for report generation from this file. ' +
"Defaults to '.coverage'. [env: COVERAGE_FILE]"
),
)
datafile_output = optparse.make_option(
"", "--data-file", action="store", metavar="OUTFILE",
'', '--data-file', action='store', metavar='OUTFILE',
help=(
"Write the recorded coverage data to this file. " +
'Write the recorded coverage data to this file. ' +
"Defaults to '.coverage'. [env: COVERAGE_FILE]"
),
)
debug = optparse.make_option(
"", "--debug", action="store", metavar="OPTS",
help="Debug options, separated by commas. [env: COVERAGE_DEBUG]",
'', '--debug', action='store', metavar='OPTS',
help='Debug options, separated by commas. [env: COVERAGE_DEBUG]',
)
directory = optparse.make_option(
"-d", "--directory", action="store", metavar="DIR",
help="Write the output files to DIR.",
'-d', '--directory', action='store', metavar='DIR',
help='Write the output files to DIR.',
)
fail_under = optparse.make_option(
"", "--fail-under", action="store", metavar="MIN", type="float",
help="Exit with a status of 2 if the total coverage is less than MIN.",
'', '--fail-under', action='store', metavar='MIN', type='float',
help='Exit with a status of 2 if the total coverage is less than MIN.',
)
format = optparse.make_option(
"", "--format", action="store", metavar="FORMAT",
help="Output format, either text (default), markdown, or total.",
'', '--format', action='store', metavar='FORMAT',
help='Output format, either text (default), markdown, or total.',
)
help = optparse.make_option(
"-h", "--help", action="store_true",
help="Get help on this command.",
'-h', '--help', action='store_true',
help='Get help on this command.',
)
ignore_errors = optparse.make_option(
"-i", "--ignore-errors", action="store_true",
help="Ignore errors while reading source files.",
'-i', '--ignore-errors', action='store_true',
help='Ignore errors while reading source files.',
)
include = optparse.make_option(
"", "--include", action="store", metavar="PAT1,PAT2,...",
'', '--include', action='store', metavar='PAT1,PAT2,...',
help=(
"Include only files whose paths match one of these patterns. " +
"Accepts shell-style wildcards, which must be quoted."
'Include only files whose paths match one of these patterns. ' +
'Accepts shell-style wildcards, which must be quoted.'
),
)
keep = optparse.make_option(
"", "--keep", action="store_true",
help="Keep original coverage files, otherwise they are deleted.",
'', '--keep', action='store_true',
help='Keep original coverage files, otherwise they are deleted.',
)
pylib = optparse.make_option(
"-L", "--pylib", action="store_true",
'-L', '--pylib', action='store_true',
help=(
"Measure coverage even inside the Python installed library, " +
'Measure coverage even inside the Python installed library, ' +
"which isn't done by default."
),
)
show_missing = optparse.make_option(
"-m", "--show-missing", action="store_true",
'-m', '--show-missing', action='store_true',
help="Show line numbers of statements in each module that weren't executed.",
)
module = optparse.make_option(
"-m", "--module", action="store_true",
'-m', '--module', action='store_true',
help=(
"<pyfile> is an importable Python module, not a script path, " +
'<pyfile> is an importable Python module, not a script path, ' +
"to be run as 'python -m' would run it."
),
)
omit = optparse.make_option(
"", "--omit", action="store", metavar="PAT1,PAT2,...",
'', '--omit', action='store', metavar='PAT1,PAT2,...',
help=(
"Omit files whose paths match one of these patterns. " +
"Accepts shell-style wildcards, which must be quoted."
'Omit files whose paths match one of these patterns. ' +
'Accepts shell-style wildcards, which must be quoted.'
),
)
output_xml = optparse.make_option(
"-o", "", action="store", dest="outfile", metavar="OUTFILE",
'-o', '', action='store', dest='outfile', metavar='OUTFILE',
help="Write the XML report to this file. Defaults to 'coverage.xml'",
)
output_json = optparse.make_option(
"-o", "", action="store", dest="outfile", metavar="OUTFILE",
'-o', '', action='store', dest='outfile', metavar='OUTFILE',
help="Write the JSON report to this file. Defaults to 'coverage.json'",
)
output_lcov = optparse.make_option(
"-o", "", action="store", dest="outfile", metavar="OUTFILE",
'-o', '', action='store', dest='outfile', metavar='OUTFILE',
help="Write the LCOV report to this file. Defaults to 'coverage.lcov'",
)
json_pretty_print = optparse.make_option(
"", "--pretty-print", action="store_true",
help="Format the JSON for human readers.",
'', '--pretty-print', action='store_true',
help='Format the JSON for human readers.',
)
parallel_mode = optparse.make_option(
"-p", "--parallel-mode", action="store_true",
'-p', '--parallel-mode', action='store_true',
help=(
"Append the machine name, process id and random number to the " +
"data file name to simplify collecting data from " +
"many processes."
'Append the machine name, process id and random number to the ' +
'data file name to simplify collecting data from ' +
'many processes.'
),
)
precision = optparse.make_option(
"", "--precision", action="store", metavar="N", type=int,
'', '--precision', action='store', metavar='N', type=int,
help=(
"Number of digits after the decimal point to display for " +
"reported coverage percentages."
'Number of digits after the decimal point to display for ' +
'reported coverage percentages.'
),
)
quiet = optparse.make_option(
"-q", "--quiet", action="store_true",
'-q', '--quiet', action='store_true',
help="Don't print messages about what is happening.",
)
rcfile = optparse.make_option(
"", "--rcfile", action="store",
'', '--rcfile', action='store',
help=(
"Specify configuration file. " +
'Specify configuration file. ' +
"By default '.coveragerc', 'setup.cfg', 'tox.ini', and " +
"'pyproject.toml' are tried. [env: COVERAGE_RCFILE]"
),
)
show_contexts = optparse.make_option(
"--show-contexts", action="store_true",
help="Show contexts for covered lines.",
'--show-contexts', action='store_true',
help='Show contexts for covered lines.',
)
skip_covered = optparse.make_option(
"--skip-covered", action="store_true",
help="Skip files with 100% coverage.",
'--skip-covered', action='store_true',
help='Skip files with 100% coverage.',
)
no_skip_covered = optparse.make_option(
"--no-skip-covered", action="store_false", dest="skip_covered",
help="Disable --skip-covered.",
'--no-skip-covered', action='store_false', dest='skip_covered',
help='Disable --skip-covered.',
)
skip_empty = optparse.make_option(
"--skip-empty", action="store_true",
help="Skip files with no code.",
'--skip-empty', action='store_true',
help='Skip files with no code.',
)
sort = optparse.make_option(
"--sort", action="store", metavar="COLUMN",
'--sort', action='store', metavar='COLUMN',
help=(
"Sort the report by the named column: name, stmts, miss, branch, brpart, or cover. " +
"Default is name."
'Sort the report by the named column: name, stmts, miss, branch, brpart, or cover. ' +
'Default is name.'
),
)
source = optparse.make_option(
"", "--source", action="store", metavar="SRC1,SRC2,...",
help="A list of directories or importable names of code to measure.",
'', '--source', action='store', metavar='SRC1,SRC2,...',
help='A list of directories or importable names of code to measure.',
)
timid = optparse.make_option(
"", "--timid", action="store_true",
help="Use the slower Python trace function core.",
'', '--timid', action='store_true',
help='Use the slower Python trace function core.',
)
title = optparse.make_option(
"", "--title", action="store", metavar="TITLE",
help="A text string to use as the title on the HTML.",
'', '--title', action='store', metavar='TITLE',
help='A text string to use as the title on the HTML.',
)
version = optparse.make_option(
"", "--version", action="store_true",
help="Display version information and exit.",
'', '--version', action='store_true',
help='Display version information and exit.',
)
@ -238,7 +243,7 @@ class CoverageOptionParser(optparse.OptionParser):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["add_help_option"] = False
kwargs['add_help_option'] = False
super().__init__(*args, **kwargs)
self.set_defaults(
# Keep these arguments alphabetized by their names.
@ -330,7 +335,7 @@ class CmdOptionParser(CoverageOptionParser):
"""
if usage:
usage = "%prog " + usage
usage = '%prog ' + usage
super().__init__(
usage=usage,
description=description,
@ -342,7 +347,7 @@ class CmdOptionParser(CoverageOptionParser):
def __eq__(self, other: str) -> bool: # type: ignore[override]
# A convenience equality, so that I can put strings in unit test
# results, and they will compare equal to objects.
return (other == f"<CmdOptionParser:{self.cmd}>")
return (other == f'<CmdOptionParser:{self.cmd}>')
__hash__ = None # type: ignore[assignment]
@ -351,7 +356,7 @@ class CmdOptionParser(CoverageOptionParser):
program_name = super().get_prog_name()
# Include the sub-command for this parser as part of the command.
return f"{program_name} {self.cmd}"
return f'{program_name} {self.cmd}'
# In lists of Opts, keep them alphabetized by the option names as they appear
# on the command line, since these lists determine the order of the options in
@ -359,6 +364,7 @@ class CmdOptionParser(CoverageOptionParser):
#
# In COMMANDS, keep the keys (command names) alphabetized.
GLOBAL_ARGS = [
Opts.debug,
Opts.help,
@ -366,72 +372,72 @@ GLOBAL_ARGS = [
]
COMMANDS = {
"annotate": CmdOptionParser(
"annotate",
'annotate': CmdOptionParser(
'annotate',
[
Opts.directory,
Opts.datafle_input,
Opts.ignore_errors,
Opts.include,
Opts.omit,
] + GLOBAL_ARGS,
usage="[options] [modules]",
] + GLOBAL_ARGS,
usage='[options] [modules]',
description=(
"Make annotated copies of the given files, marking statements that are executed " +
"with > and statements that are missed with !."
'Make annotated copies of the given files, marking statements that are executed ' +
'with > and statements that are missed with !.'
),
),
"combine": CmdOptionParser(
"combine",
'combine': CmdOptionParser(
'combine',
[
Opts.append,
Opts.datafile,
Opts.keep,
Opts.quiet,
] + GLOBAL_ARGS,
usage="[options] <path1> <path2> ... <pathN>",
] + GLOBAL_ARGS,
usage='[options] <path1> <path2> ... <pathN>',
description=(
"Combine data from multiple coverage files. " +
"The combined results are written to a single " +
"file representing the union of the data. The positional " +
"arguments are data files or directories containing data files. " +
'Combine data from multiple coverage files. ' +
'The combined results are written to a single ' +
'file representing the union of the data. The positional ' +
'arguments are data files or directories containing data files. ' +
"If no paths are provided, data files in the default data file's " +
"directory are combined."
'directory are combined.'
),
),
"debug": CmdOptionParser(
"debug", GLOBAL_ARGS,
usage="<topic>",
'debug': CmdOptionParser(
'debug', GLOBAL_ARGS,
usage='<topic>',
description=(
"Display information about the internals of coverage.py, " +
"for diagnosing problems. " +
"Topics are: " +
"'data' to show a summary of the collected data; " +
"'sys' to show installation information; " +
"'config' to show the configuration; " +
"'premain' to show what is calling coverage; " +
"'pybehave' to show internal flags describing Python behavior."
'Display information about the internals of coverage.py, ' +
'for diagnosing problems. ' +
'Topics are: ' +
"'data' to show a summary of the collected data; " +
"'sys' to show installation information; " +
"'config' to show the configuration; " +
"'premain' to show what is calling coverage; " +
"'pybehave' to show internal flags describing Python behavior."
),
),
"erase": CmdOptionParser(
"erase",
'erase': CmdOptionParser(
'erase',
[
Opts.datafile,
] + GLOBAL_ARGS,
description="Erase previously collected coverage data.",
] + GLOBAL_ARGS,
description='Erase previously collected coverage data.',
),
"help": CmdOptionParser(
"help", GLOBAL_ARGS,
usage="[command]",
description="Describe how to use coverage.py",
'help': CmdOptionParser(
'help', GLOBAL_ARGS,
usage='[command]',
description='Describe how to use coverage.py',
),
"html": CmdOptionParser(
"html",
'html': CmdOptionParser(
'html',
[
Opts.contexts,
Opts.directory,
@ -447,17 +453,17 @@ COMMANDS = {
Opts.no_skip_covered,
Opts.skip_empty,
Opts.title,
] + GLOBAL_ARGS,
usage="[options] [modules]",
] + GLOBAL_ARGS,
usage='[options] [modules]',
description=(
"Create an HTML report of the coverage of the files. " +
"Each file gets its own page, with the source decorated to show " +
"executed, excluded, and missed lines."
'Create an HTML report of the coverage of the files. ' +
'Each file gets its own page, with the source decorated to show ' +
'executed, excluded, and missed lines.'
),
),
"json": CmdOptionParser(
"json",
'json': CmdOptionParser(
'json',
[
Opts.contexts,
Opts.datafle_input,
@ -469,13 +475,13 @@ COMMANDS = {
Opts.json_pretty_print,
Opts.quiet,
Opts.show_contexts,
] + GLOBAL_ARGS,
usage="[options] [modules]",
description="Generate a JSON report of coverage results.",
] + GLOBAL_ARGS,
usage='[options] [modules]',
description='Generate a JSON report of coverage results.',
),
"lcov": CmdOptionParser(
"lcov",
'lcov': CmdOptionParser(
'lcov',
[
Opts.datafle_input,
Opts.fail_under,
@ -484,13 +490,13 @@ COMMANDS = {
Opts.output_lcov,
Opts.omit,
Opts.quiet,
] + GLOBAL_ARGS,
usage="[options] [modules]",
description="Generate an LCOV report of coverage results.",
] + GLOBAL_ARGS,
usage='[options] [modules]',
description='Generate an LCOV report of coverage results.',
),
"report": CmdOptionParser(
"report",
'report': CmdOptionParser(
'report',
[
Opts.contexts,
Opts.datafle_input,
@ -505,13 +511,13 @@ COMMANDS = {
Opts.skip_covered,
Opts.no_skip_covered,
Opts.skip_empty,
] + GLOBAL_ARGS,
usage="[options] [modules]",
description="Report coverage statistics on modules.",
] + GLOBAL_ARGS,
usage='[options] [modules]',
description='Report coverage statistics on modules.',
),
"run": CmdOptionParser(
"run",
'run': CmdOptionParser(
'run',
[
Opts.append,
Opts.branch,
@ -525,13 +531,13 @@ COMMANDS = {
Opts.parallel_mode,
Opts.source,
Opts.timid,
] + GLOBAL_ARGS,
usage="[options] <pyfile> [program options]",
description="Run a Python program, measuring code execution.",
] + GLOBAL_ARGS,
usage='[options] <pyfile> [program options]',
description='Run a Python program, measuring code execution.',
),
"xml": CmdOptionParser(
"xml",
'xml': CmdOptionParser(
'xml',
[
Opts.datafle_input,
Opts.fail_under,
@ -541,9 +547,9 @@ COMMANDS = {
Opts.output_xml,
Opts.quiet,
Opts.skip_empty,
] + GLOBAL_ARGS,
usage="[options] [modules]",
description="Generate an XML report of coverage results.",
] + GLOBAL_ARGS,
usage='[options] [modules]',
description='Generate an XML report of coverage results.',
),
}
@ -557,7 +563,7 @@ def show_help(
assert error or topic or parser
program_path = sys.argv[0]
if program_path.endswith(os.path.sep + "__main__.py"):
if program_path.endswith(os.path.sep + '__main__.py'):
# The path is the main module of a package; get that path instead.
program_path = os.path.dirname(program_path)
program_name = os.path.basename(program_path)
@ -567,17 +573,17 @@ def show_help(
# invoke coverage-script.py, coverage3-script.py, and
# coverage-3.5-script.py. argv[0] is the .py file, but we want to
# get back to the original form.
auto_suffix = "-script.py"
auto_suffix = '-script.py'
if program_name.endswith(auto_suffix):
program_name = program_name[:-len(auto_suffix)]
help_params = dict(coverage.__dict__)
help_params["__url__"] = __url__
help_params["program_name"] = program_name
help_params['__url__'] = __url__
help_params['program_name'] = program_name
if HAS_CTRACER:
help_params["extension_modifier"] = "with C extension"
help_params['extension_modifier'] = 'with C extension'
else:
help_params["extension_modifier"] = "without C extension"
help_params['extension_modifier'] = 'without C extension'
if error:
print(error, file=sys.stderr)
@ -587,12 +593,12 @@ def show_help(
print()
else:
assert topic is not None
help_msg = textwrap.dedent(HELP_TOPICS.get(topic, "")).strip()
help_msg = textwrap.dedent(HELP_TOPICS.get(topic, '')).strip()
if help_msg:
print(help_msg.format(**help_params))
else:
print(f"Don't know topic {topic!r}")
print("Full documentation is at {__url__}".format(**help_params))
print('Full documentation is at {__url__}'.format(**help_params))
OK, ERR, FAIL_UNDER = 0, 1, 2
@ -615,19 +621,19 @@ class CoverageScript:
"""
# Collect the command-line options.
if not argv:
show_help(topic="minimum_help")
show_help(topic='minimum_help')
return OK
# The command syntax we parse depends on the first argument. Global
# switch syntax always starts with an option.
parser: optparse.OptionParser | None
self.global_option = argv[0].startswith("-")
self.global_option = argv[0].startswith('-')
if self.global_option:
parser = GlobalOptionParser()
else:
parser = COMMANDS.get(argv[0])
if not parser:
show_help(f"Unknown command: {argv[0]!r}")
show_help(f'Unknown command: {argv[0]!r}')
return ERR
argv = argv[1:]
@ -648,7 +654,7 @@ class CoverageScript:
contexts = unshell_list(options.contexts)
if options.concurrency is not None:
concurrency = options.concurrency.split(",")
concurrency = options.concurrency.split(',')
else:
concurrency = None
@ -670,17 +676,17 @@ class CoverageScript:
messages=not options.quiet,
)
if options.action == "debug":
if options.action == 'debug':
return self.do_debug(args)
elif options.action == "erase":
elif options.action == 'erase':
self.coverage.erase()
return OK
elif options.action == "run":
elif options.action == 'run':
return self.do_run(options, args)
elif options.action == "combine":
elif options.action == 'combine':
if options.append:
self.coverage.load()
data_paths = args or None
@ -699,12 +705,12 @@ class CoverageScript:
# We need to be able to import from the current directory, because
# plugins may try to, for example, to read Django settings.
sys.path.insert(0, "")
sys.path.insert(0, '')
self.coverage.load()
total = None
if options.action == "report":
if options.action == 'report':
total = self.coverage.report(
precision=options.precision,
show_missing=options.show_missing,
@ -714,9 +720,9 @@ class CoverageScript:
output_format=options.format,
**report_args,
)
elif options.action == "annotate":
elif options.action == 'annotate':
self.coverage.annotate(directory=options.directory, **report_args)
elif options.action == "html":
elif options.action == 'html':
total = self.coverage.html_report(
directory=options.directory,
precision=options.precision,
@ -726,20 +732,20 @@ class CoverageScript:
title=options.title,
**report_args,
)
elif options.action == "xml":
elif options.action == 'xml':
total = self.coverage.xml_report(
outfile=options.outfile,
skip_empty=options.skip_empty,
**report_args,
)
elif options.action == "json":
elif options.action == 'json':
total = self.coverage.json_report(
outfile=options.outfile,
pretty_print=options.pretty_print,
show_contexts=options.show_contexts,
**report_args,
)
elif options.action == "lcov":
elif options.action == 'lcov':
total = self.coverage.lcov_report(
outfile=options.outfile,
**report_args,
@ -752,19 +758,19 @@ class CoverageScript:
# Apply the command line fail-under options, and then use the config
# value, so we can get fail_under from the config file.
if options.fail_under is not None:
self.coverage.set_option("report:fail_under", options.fail_under)
self.coverage.set_option('report:fail_under', options.fail_under)
if options.precision is not None:
self.coverage.set_option("report:precision", options.precision)
self.coverage.set_option('report:precision', options.precision)
fail_under = cast(float, self.coverage.get_option("report:fail_under"))
precision = cast(int, self.coverage.get_option("report:precision"))
fail_under = cast(float, self.coverage.get_option('report:fail_under'))
precision = cast(int, self.coverage.get_option('report:precision'))
if should_fail_under(total, fail_under, precision):
msg = "total of {total} is less than fail-under={fail_under:.{p}f}".format(
msg = 'total of {total} is less than fail-under={fail_under:.{p}f}'.format(
total=Numbers(precision=precision).display_covered(total),
fail_under=fail_under,
p=precision,
)
print("Coverage failure:", msg)
print('Coverage failure:', msg)
return FAIL_UNDER
return OK
@ -783,12 +789,12 @@ class CoverageScript:
# Handle help.
if options.help:
if self.global_option:
show_help(topic="help")
show_help(topic='help')
else:
show_help(parser=parser)
return True
if options.action == "help":
if options.action == 'help':
if args:
for a in args:
parser_maybe = COMMANDS.get(a)
@ -797,12 +803,12 @@ class CoverageScript:
else:
show_help(topic=a)
else:
show_help(topic="help")
show_help(topic='help')
return True
# Handle version.
if options.version:
show_help(topic="version")
show_help(topic='version')
return True
return False
@ -813,37 +819,37 @@ class CoverageScript:
if not args:
if options.module:
# Specified -m with nothing else.
show_help("No module specified for -m")
show_help('No module specified for -m')
return ERR
command_line = cast(str, self.coverage.get_option("run:command_line"))
command_line = cast(str, self.coverage.get_option('run:command_line'))
if command_line is not None:
args = shlex.split(command_line)
if args and args[0] in {"-m", "--module"}:
if args and args[0] in {'-m', '--module'}:
options.module = True
args = args[1:]
if not args:
show_help("Nothing to do.")
show_help('Nothing to do.')
return ERR
if options.append and self.coverage.get_option("run:parallel"):
if options.append and self.coverage.get_option('run:parallel'):
show_help("Can't append to data files in parallel mode.")
return ERR
if options.concurrency == "multiprocessing":
if options.concurrency == 'multiprocessing':
# Can't set other run-affecting command line options with
# multiprocessing.
for opt_name in ["branch", "include", "omit", "pylib", "source", "timid"]:
for opt_name in ['branch', 'include', 'omit', 'pylib', 'source', 'timid']:
# As it happens, all of these options have no default, meaning
# they will be None if they have not been specified.
if getattr(options, opt_name) is not None:
show_help(
"Options affecting multiprocessing must only be specified " +
"in a configuration file.\n" +
f"Remove --{opt_name} from the command line.",
'Options affecting multiprocessing must only be specified ' +
'in a configuration file.\n' +
f'Remove --{opt_name} from the command line.',
)
return ERR
os.environ["COVERAGE_RUN"] = "true"
os.environ['COVERAGE_RUN'] = 'true'
runner = PyRunner(args, as_module=bool(options.module))
runner.prepare()
@ -870,28 +876,28 @@ class CoverageScript:
"""Implementation of 'coverage debug'."""
if not args:
show_help("What information would you like: config, data, sys, premain, pybehave?")
show_help('What information would you like: config, data, sys, premain, pybehave?')
return ERR
if args[1:]:
show_help("Only one topic at a time, please")
show_help('Only one topic at a time, please')
return ERR
if args[0] == "sys":
write_formatted_info(print, "sys", self.coverage.sys_info())
elif args[0] == "data":
print(info_header("data"))
if args[0] == 'sys':
write_formatted_info(print, 'sys', self.coverage.sys_info())
elif args[0] == 'data':
print(info_header('data'))
data_file = self.coverage.config.data_file
debug_data_file(data_file)
for filename in combinable_files(data_file):
print("-----")
print('-----')
debug_data_file(filename)
elif args[0] == "config":
write_formatted_info(print, "config", self.coverage.config.debug_info())
elif args[0] == "premain":
print(info_header("premain"))
elif args[0] == 'config':
write_formatted_info(print, 'config', self.coverage.config.debug_info())
elif args[0] == 'premain':
print(info_header('premain'))
print(short_stack(full=True))
elif args[0] == "pybehave":
write_formatted_info(print, "pybehave", env.debug_info())
elif args[0] == 'pybehave':
write_formatted_info(print, 'pybehave', env.debug_info())
else:
show_help(f"Don't know what you mean by {args[0]!r}")
return ERR
@ -910,7 +916,7 @@ def unshell_list(s: str) -> list[str] | None:
# line, but (not) helpfully, the single quotes are included in the
# argument, so we have to strip them off here.
s = s.strip("'")
return s.split(",")
return s.split(',')
def unglob_args(args: list[str]) -> list[str]:
@ -918,7 +924,7 @@ def unglob_args(args: list[str]) -> list[str]:
if env.WINDOWS:
globbed = []
for arg in args:
if "?" in arg or "*" in arg:
if '?' in arg or '*' in arg:
globbed.extend(glob.glob(arg))
else:
globbed.append(arg)
@ -927,7 +933,7 @@ def unglob_args(args: list[str]) -> list[str]:
HELP_TOPICS = {
"help": """\
'help': """\
Coverage.py, version {__version__} {extension_modifier}
Measure, collect, and report on code coverage in Python programs.
@ -949,12 +955,12 @@ HELP_TOPICS = {
Use "{program_name} help <command>" for detailed help on any command.
""",
"minimum_help": (
"Code coverage for Python, version {__version__} {extension_modifier}. " +
'minimum_help': (
'Code coverage for Python, version {__version__} {extension_modifier}. ' +
"Use '{program_name} help' for help."
),
"version": "Coverage.py, version {__version__} {extension_modifier}",
'version': 'Coverage.py, version {__version__} {extension_modifier}',
}
@ -987,11 +993,12 @@ def main(argv: list[str] | None = None) -> int | None:
status = None
return status
# Profiling using ox_profile. Install it from GitHub:
# pip install git+https://github.com/emin63/ox_profile.git
#
# $set_env.py: COVERAGE_PROFILE - Set to use ox_profile.
_profile = os.getenv("COVERAGE_PROFILE")
_profile = os.getenv('COVERAGE_PROFILE')
if _profile: # pragma: debugging
from ox_profile.core.launchers import SimpleLauncher # pylint: disable=import-error
original_main = main
@ -1004,6 +1011,6 @@ if _profile: # pragma: debugging
try:
return original_main(argv)
finally:
data, _ = profiler.query(re_filter="coverage", max_records=100)
print(profiler.show(query=data, limit=100, sep="", col=""))
data, _ = profiler.query(re_filter='coverage', max_records=100)
print(profiler.show(query=data, limit=100, sep='', col=''))
profiler.cancel()

View file

@ -1,18 +1,20 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Raw data collector for coverage.py."""
from __future__ import annotations
import functools
import os
import sys
from types import FrameType
from typing import (
cast, Any, Callable, Dict, List, Mapping, Set, TypeVar,
)
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import Set
from typing import TypeVar
from coverage import env
from coverage.config import CoverageConfig
@ -20,13 +22,17 @@ from coverage.data import CoverageData
from coverage.debug import short_stack
from coverage.disposition import FileDisposition
from coverage.exceptions import ConfigError
from coverage.misc import human_sorted_items, isolate_module
from coverage.misc import human_sorted_items
from coverage.misc import isolate_module
from coverage.plugin import CoveragePlugin
from coverage.pytracer import PyTracer
from coverage.sysmon import SysMonitor
from coverage.types import (
TArc, TFileDisposition, TTraceData, TTraceFn, TracerCore, TWarnFn,
)
from coverage.types import TArc
from coverage.types import TFileDisposition
from coverage.types import TracerCore
from coverage.types import TTraceData
from coverage.types import TTraceFn
from coverage.types import TWarnFn
os = isolate_module(os)
@ -37,7 +43,7 @@ try:
HAS_CTRACER = True
except ImportError:
# Couldn't import the C extension, maybe it isn't built.
if os.getenv("COVERAGE_CORE") == "ctrace": # pragma: part covered
if os.getenv('COVERAGE_CORE') == 'ctrace': # pragma: part covered
# During testing, we use the COVERAGE_CORE environment variable
# to indicate that we've fiddled with the environment to test this
# fallback code. If we thought we had a C tracer, but couldn't import
@ -48,7 +54,7 @@ except ImportError:
sys.exit(1)
HAS_CTRACER = False
T = TypeVar("T")
T = TypeVar('T')
class Collector:
@ -73,7 +79,7 @@ class Collector:
_collectors: list[Collector] = []
# The concurrency settings we support here.
LIGHT_THREADS = {"greenlet", "eventlet", "gevent"}
LIGHT_THREADS = {'greenlet', 'eventlet', 'gevent'}
def __init__(
self,
@ -130,7 +136,7 @@ class Collector:
self.branch = branch
self.warn = warn
self.concurrency = concurrency
assert isinstance(self.concurrency, list), f"Expected a list: {self.concurrency!r}"
assert isinstance(self.concurrency, list), f'Expected a list: {self.concurrency!r}'
self.pid = os.getpid()
@ -147,12 +153,12 @@ class Collector:
core: str | None
if timid:
core = "pytrace"
core = 'pytrace'
else:
core = os.getenv("COVERAGE_CORE")
core = os.getenv('COVERAGE_CORE')
if core == "sysmon" and not env.PYBEHAVIOR.pep669:
self.warn("sys.monitoring isn't available, using default core", slug="no-sysmon")
if core == 'sysmon' and not env.PYBEHAVIOR.pep669:
self.warn("sys.monitoring isn't available, using default core", slug='no-sysmon')
core = None
if not core:
@ -160,25 +166,25 @@ class Collector:
# if env.PYBEHAVIOR.pep669 and self.should_start_context is None:
# core = "sysmon"
if HAS_CTRACER:
core = "ctrace"
core = 'ctrace'
else:
core = "pytrace"
core = 'pytrace'
if core == "sysmon":
if core == 'sysmon':
self._trace_class = SysMonitor
self._core_kwargs = {"tool_id": 3 if metacov else 1}
self._core_kwargs = {'tool_id': 3 if metacov else 1}
self.file_disposition_class = FileDisposition
self.supports_plugins = False
self.packed_arcs = False
self.systrace = False
elif core == "ctrace":
elif core == 'ctrace':
self._trace_class = CTracer
self._core_kwargs = {}
self.file_disposition_class = CFileDisposition
self.supports_plugins = True
self.packed_arcs = True
self.systrace = True
elif core == "pytrace":
elif core == 'pytrace':
self._trace_class = PyTracer
self._core_kwargs = {}
self.file_disposition_class = FileDisposition
@ -186,42 +192,42 @@ class Collector:
self.packed_arcs = False
self.systrace = True
else:
raise ConfigError(f"Unknown core value: {core!r}")
raise ConfigError(f'Unknown core value: {core!r}')
# We can handle a few concurrency options here, but only one at a time.
concurrencies = set(self.concurrency)
unknown = concurrencies - CoverageConfig.CONCURRENCY_CHOICES
if unknown:
show = ", ".join(sorted(unknown))
raise ConfigError(f"Unknown concurrency choices: {show}")
show = ', '.join(sorted(unknown))
raise ConfigError(f'Unknown concurrency choices: {show}')
light_threads = concurrencies & self.LIGHT_THREADS
if len(light_threads) > 1:
show = ", ".join(sorted(light_threads))
raise ConfigError(f"Conflicting concurrency settings: {show}")
show = ', '.join(sorted(light_threads))
raise ConfigError(f'Conflicting concurrency settings: {show}')
do_threading = False
tried = "nothing" # to satisfy pylint
tried = 'nothing' # to satisfy pylint
try:
if "greenlet" in concurrencies:
tried = "greenlet"
if 'greenlet' in concurrencies:
tried = 'greenlet'
import greenlet
self.concur_id_func = greenlet.getcurrent
elif "eventlet" in concurrencies:
tried = "eventlet"
elif 'eventlet' in concurrencies:
tried = 'eventlet'
import eventlet.greenthread # pylint: disable=import-error,useless-suppression
self.concur_id_func = eventlet.greenthread.getcurrent
elif "gevent" in concurrencies:
tried = "gevent"
elif 'gevent' in concurrencies:
tried = 'gevent'
import gevent # pylint: disable=import-error,useless-suppression
self.concur_id_func = gevent.getcurrent
if "thread" in concurrencies:
if 'thread' in concurrencies:
do_threading = True
except ImportError as ex:
msg = f"Couldn't trace with concurrency={tried}, the module isn't installed."
raise ConfigError(msg) from ex
if self.concur_id_func and not hasattr(self._trace_class, "concur_id_func"):
if self.concur_id_func and not hasattr(self._trace_class, 'concur_id_func'):
raise ConfigError(
"Can't support concurrency={} with {}, only threads are supported.".format(
tried, self.tracer_name(),
@ -238,7 +244,7 @@ class Collector:
self.reset()
def __repr__(self) -> str:
return f"<Collector at {id(self):#x}: {self.tracer_name()}>"
return f'<Collector at {id(self):#x}: {self.tracer_name()}>'
def use_data(self, covdata: CoverageData, context: str | None) -> None:
"""Use `covdata` for recording data."""
@ -296,7 +302,7 @@ class Collector:
#
# This gives a 20% benefit on the workload described at
# https://bitbucket.org/pypy/pypy/issue/1871/10x-slower-than-cpython-under-coverage
self.should_trace_cache = __pypy__.newdict("module")
self.should_trace_cache = __pypy__.newdict('module')
else:
self.should_trace_cache = {}
@ -394,9 +400,9 @@ class Collector:
"""Stop collecting trace information."""
assert self._collectors
if self._collectors[-1] is not self:
print("self._collectors:")
print('self._collectors:')
for c in self._collectors:
print(f" {c!r}\n{c.origin}")
print(f' {c!r}\n{c.origin}')
assert self._collectors[-1] is self, (
f"Expected current collector to be {self!r}, but it's {self._collectors[-1]!r}"
)
@ -414,9 +420,9 @@ class Collector:
tracer.stop()
stats = tracer.get_stats()
if stats:
print("\nCoverage.py tracer stats:")
print('\nCoverage.py tracer stats:')
for k, v in human_sorted_items(stats.items()):
print(f"{k:>20}: {v}")
print(f'{k:>20}: {v}')
if self.threading:
self.threading.settrace(None)
@ -433,7 +439,7 @@ class Collector:
def post_fork(self) -> None:
"""After a fork, tracers might need to adjust."""
for tracer in self.tracers:
if hasattr(tracer, "post_fork"):
if hasattr(tracer, 'post_fork'):
tracer.post_fork()
def _activity(self) -> bool:
@ -451,7 +457,7 @@ class Collector:
if self.static_context:
context = self.static_context
if new_context:
context += "|" + new_context
context += '|' + new_context
else:
context = new_context
self.covdata.set_context(context)
@ -462,7 +468,7 @@ class Collector:
assert file_tracer is not None
plugin = file_tracer._coverage_plugin
plugin_name = plugin._coverage_plugin_name
self.warn(f"Disabling plug-in {plugin_name!r} due to previous exception")
self.warn(f'Disabling plug-in {plugin_name!r} due to previous exception')
plugin._coverage_enabled = False
disposition.trace = False

View file

@ -1,28 +1,30 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Config file for coverage.py"""
from __future__ import annotations
import collections
import configparser
import copy
import os
import os.path
import re
from typing import (
Any, Callable, Iterable, Union,
)
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Union
from coverage.exceptions import ConfigError
from coverage.misc import isolate_module, human_sorted_items, substitute_variables
from coverage.tomlconfig import TomlConfigParser, TomlDecodeError
from coverage.types import (
TConfigurable, TConfigSectionIn, TConfigValueIn, TConfigSectionOut,
TConfigValueOut, TPluginConfig,
)
from coverage.misc import human_sorted_items
from coverage.misc import isolate_module
from coverage.misc import substitute_variables
from coverage.tomlconfig import TomlConfigParser
from coverage.tomlconfig import TomlDecodeError
from coverage.types import TConfigSectionIn
from coverage.types import TConfigSectionOut
from coverage.types import TConfigurable
from coverage.types import TConfigValueIn
from coverage.types import TConfigValueOut
from coverage.types import TPluginConfig
os = isolate_module(os)
@ -39,17 +41,17 @@ class HandyConfigParser(configparser.ConfigParser):
"""
super().__init__(interpolation=None)
self.section_prefixes = ["coverage:"]
self.section_prefixes = ['coverage:']
if our_file:
self.section_prefixes.append("")
self.section_prefixes.append('')
def read( # type: ignore[override]
def read( # type: ignore[override]
self,
filenames: Iterable[str],
encoding_unused: str | None = None,
) -> list[str]:
"""Read a file name as UTF-8 configuration data."""
return super().read(filenames, encoding="utf-8")
return super().read(filenames, encoding='utf-8')
def real_section(self, section: str) -> str | None:
"""Get the actual name of a section."""
@ -73,7 +75,7 @@ class HandyConfigParser(configparser.ConfigParser):
real_section = self.real_section(section)
if real_section is not None:
return super().options(real_section)
raise ConfigError(f"No section: {section!r}")
raise ConfigError(f'No section: {section!r}')
def get_section(self, section: str) -> TConfigSectionOut:
"""Get the contents of a section, as a dictionary."""
@ -82,7 +84,7 @@ class HandyConfigParser(configparser.ConfigParser):
d[opt] = self.get(section, opt)
return d
def get(self, section: str, option: str, *args: Any, **kwargs: Any) -> str: # type: ignore
def get(self, section: str, option: str, *args: Any, **kwargs: Any) -> str: # type: ignore
"""Get a value, replacing environment variables also.
The arguments are the same as `ConfigParser.get`, but in the found
@ -97,7 +99,7 @@ class HandyConfigParser(configparser.ConfigParser):
if super().has_option(real_section, option):
break
else:
raise ConfigError(f"No option {option!r} in section: {section!r}")
raise ConfigError(f'No option {option!r} in section: {section!r}')
v: str = super().get(real_section, option, *args, **kwargs)
v = substitute_variables(v, os.environ)
@ -114,8 +116,8 @@ class HandyConfigParser(configparser.ConfigParser):
"""
value_list = self.get(section, option)
values = []
for value_line in value_list.split("\n"):
for value in value_line.split(","):
for value_line in value_list.split('\n'):
for value in value_line.split(','):
value = value.strip()
if value:
values.append(value)
@ -138,7 +140,7 @@ class HandyConfigParser(configparser.ConfigParser):
re.compile(value)
except re.error as e:
raise ConfigError(
f"Invalid [{section}].{option} value {value!r}: {e}",
f'Invalid [{section}].{option} value {value!r}: {e}',
) from e
if value:
value_list.append(value)
@ -150,20 +152,20 @@ TConfigParser = Union[HandyConfigParser, TomlConfigParser]
# The default line exclusion regexes.
DEFAULT_EXCLUDE = [
r"#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER)",
r'#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER)',
]
# The default partial branch regexes, to be modified by the user.
DEFAULT_PARTIAL = [
r"#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(branch|BRANCH)",
r'#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(branch|BRANCH)',
]
# The default partial branch regexes, based on Python semantics.
# These are any Python branching constructs that can't actually execute all
# their branches.
DEFAULT_PARTIAL_ALWAYS = [
"while (True|1|False|0):",
"if (True|1|False|0):",
'while (True|1|False|0):',
'if (True|1|False|0):',
]
@ -197,7 +199,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
self.concurrency: list[str] = []
self.context: str | None = None
self.cover_pylib = False
self.data_file = ".coverage"
self.data_file = '.coverage'
self.debug: list[str] = []
self.debug_file: str | None = None
self.disable_warnings: list[str] = []
@ -233,23 +235,23 @@ class CoverageConfig(TConfigurable, TPluginConfig):
# Defaults for [html]
self.extra_css: str | None = None
self.html_dir = "htmlcov"
self.html_dir = 'htmlcov'
self.html_skip_covered: bool | None = None
self.html_skip_empty: bool | None = None
self.html_title = "Coverage report"
self.html_title = 'Coverage report'
self.show_contexts = False
# Defaults for [xml]
self.xml_output = "coverage.xml"
self.xml_output = 'coverage.xml'
self.xml_package_depth = 99
# Defaults for [json]
self.json_output = "coverage.json"
self.json_output = 'coverage.json'
self.json_pretty_print = False
self.json_show_contexts = False
# Defaults for [lcov]
self.lcov_output = "coverage.lcov"
self.lcov_output = 'coverage.lcov'
# Defaults for [paths]
self.paths: dict[str, list[str]] = {}
@ -258,9 +260,9 @@ class CoverageConfig(TConfigurable, TPluginConfig):
self.plugin_options: dict[str, TConfigSectionOut] = {}
MUST_BE_LIST = {
"debug", "concurrency", "plugins",
"report_omit", "report_include",
"run_omit", "run_include",
'debug', 'concurrency', 'plugins',
'report_omit', 'report_include',
'run_omit', 'run_include',
}
def from_args(self, **kwargs: TConfigValueIn) -> None:
@ -286,7 +288,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
"""
_, ext = os.path.splitext(filename)
cp: TConfigParser
if ext == ".toml":
if ext == '.toml':
cp = TomlConfigParser(our_file)
else:
cp = HandyConfigParser(our_file)
@ -314,7 +316,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
# Check that there are no unrecognized options.
all_options = collections.defaultdict(set)
for option_spec in self.CONFIG_FILE_OPTIONS:
section, option = option_spec[1].split(":")
section, option = option_spec[1].split(':')
all_options[section].add(option)
for section, options in all_options.items():
@ -328,9 +330,9 @@ class CoverageConfig(TConfigurable, TPluginConfig):
)
# [paths] is special
if cp.has_section("paths"):
for option in cp.options("paths"):
self.paths[option] = cp.getlist("paths", option)
if cp.has_section('paths'):
for option in cp.options('paths'):
self.paths[option] = cp.getlist('paths', option)
any_set = True
# plugins can have options
@ -349,7 +351,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
if used:
self.config_file = os.path.abspath(filename)
with open(filename, "rb") as f:
with open(filename, 'rb') as f:
self._config_contents = f.read()
return used
@ -358,7 +360,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
"""Return a copy of the configuration."""
return copy.deepcopy(self)
CONCURRENCY_CHOICES = {"thread", "gevent", "greenlet", "eventlet", "multiprocessing"}
CONCURRENCY_CHOICES = {'thread', 'gevent', 'greenlet', 'eventlet', 'multiprocessing'}
CONFIG_FILE_OPTIONS = [
# These are *args for _set_attr_from_config_option:
@ -370,64 +372,64 @@ class CoverageConfig(TConfigurable, TPluginConfig):
# configuration value from the file.
# [run]
("branch", "run:branch", "boolean"),
("command_line", "run:command_line"),
("concurrency", "run:concurrency", "list"),
("context", "run:context"),
("cover_pylib", "run:cover_pylib", "boolean"),
("data_file", "run:data_file"),
("debug", "run:debug", "list"),
("debug_file", "run:debug_file"),
("disable_warnings", "run:disable_warnings", "list"),
("dynamic_context", "run:dynamic_context"),
("parallel", "run:parallel", "boolean"),
("plugins", "run:plugins", "list"),
("relative_files", "run:relative_files", "boolean"),
("run_include", "run:include", "list"),
("run_omit", "run:omit", "list"),
("sigterm", "run:sigterm", "boolean"),
("source", "run:source", "list"),
("source_pkgs", "run:source_pkgs", "list"),
("timid", "run:timid", "boolean"),
("_crash", "run:_crash"),
('branch', 'run:branch', 'boolean'),
('command_line', 'run:command_line'),
('concurrency', 'run:concurrency', 'list'),
('context', 'run:context'),
('cover_pylib', 'run:cover_pylib', 'boolean'),
('data_file', 'run:data_file'),
('debug', 'run:debug', 'list'),
('debug_file', 'run:debug_file'),
('disable_warnings', 'run:disable_warnings', 'list'),
('dynamic_context', 'run:dynamic_context'),
('parallel', 'run:parallel', 'boolean'),
('plugins', 'run:plugins', 'list'),
('relative_files', 'run:relative_files', 'boolean'),
('run_include', 'run:include', 'list'),
('run_omit', 'run:omit', 'list'),
('sigterm', 'run:sigterm', 'boolean'),
('source', 'run:source', 'list'),
('source_pkgs', 'run:source_pkgs', 'list'),
('timid', 'run:timid', 'boolean'),
('_crash', 'run:_crash'),
# [report]
("exclude_list", "report:exclude_lines", "regexlist"),
("exclude_also", "report:exclude_also", "regexlist"),
("fail_under", "report:fail_under", "float"),
("format", "report:format"),
("ignore_errors", "report:ignore_errors", "boolean"),
("include_namespace_packages", "report:include_namespace_packages", "boolean"),
("partial_always_list", "report:partial_branches_always", "regexlist"),
("partial_list", "report:partial_branches", "regexlist"),
("precision", "report:precision", "int"),
("report_contexts", "report:contexts", "list"),
("report_include", "report:include", "list"),
("report_omit", "report:omit", "list"),
("show_missing", "report:show_missing", "boolean"),
("skip_covered", "report:skip_covered", "boolean"),
("skip_empty", "report:skip_empty", "boolean"),
("sort", "report:sort"),
('exclude_list', 'report:exclude_lines', 'regexlist'),
('exclude_also', 'report:exclude_also', 'regexlist'),
('fail_under', 'report:fail_under', 'float'),
('format', 'report:format'),
('ignore_errors', 'report:ignore_errors', 'boolean'),
('include_namespace_packages', 'report:include_namespace_packages', 'boolean'),
('partial_always_list', 'report:partial_branches_always', 'regexlist'),
('partial_list', 'report:partial_branches', 'regexlist'),
('precision', 'report:precision', 'int'),
('report_contexts', 'report:contexts', 'list'),
('report_include', 'report:include', 'list'),
('report_omit', 'report:omit', 'list'),
('show_missing', 'report:show_missing', 'boolean'),
('skip_covered', 'report:skip_covered', 'boolean'),
('skip_empty', 'report:skip_empty', 'boolean'),
('sort', 'report:sort'),
# [html]
("extra_css", "html:extra_css"),
("html_dir", "html:directory"),
("html_skip_covered", "html:skip_covered", "boolean"),
("html_skip_empty", "html:skip_empty", "boolean"),
("html_title", "html:title"),
("show_contexts", "html:show_contexts", "boolean"),
('extra_css', 'html:extra_css'),
('html_dir', 'html:directory'),
('html_skip_covered', 'html:skip_covered', 'boolean'),
('html_skip_empty', 'html:skip_empty', 'boolean'),
('html_title', 'html:title'),
('show_contexts', 'html:show_contexts', 'boolean'),
# [xml]
("xml_output", "xml:output"),
("xml_package_depth", "xml:package_depth", "int"),
('xml_output', 'xml:output'),
('xml_package_depth', 'xml:package_depth', 'int'),
# [json]
("json_output", "json:output"),
("json_pretty_print", "json:pretty_print", "boolean"),
("json_show_contexts", "json:show_contexts", "boolean"),
('json_output', 'json:output'),
('json_pretty_print', 'json:pretty_print', 'boolean'),
('json_show_contexts', 'json:show_contexts', 'boolean'),
# [lcov]
("lcov_output", "lcov:output"),
('lcov_output', 'lcov:output'),
]
def _set_attr_from_config_option(
@ -435,16 +437,16 @@ class CoverageConfig(TConfigurable, TPluginConfig):
cp: TConfigParser,
attr: str,
where: str,
type_: str = "",
type_: str = '',
) -> bool:
"""Set an attribute on self if it exists in the ConfigParser.
Returns True if the attribute was set.
"""
section, option = where.split(":")
section, option = where.split(':')
if cp.has_option(section, option):
method = getattr(cp, "get" + type_)
method = getattr(cp, 'get' + type_)
setattr(self, attr, method(section, option))
return True
return False
@ -464,7 +466,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
"""
# Special-cased options.
if option_name == "paths":
if option_name == 'paths':
self.paths = value # type: ignore[assignment]
return
@ -476,13 +478,13 @@ class CoverageConfig(TConfigurable, TPluginConfig):
return
# See if it's a plugin option.
plugin_name, _, key = option_name.partition(":")
plugin_name, _, key = option_name.partition(':')
if key and plugin_name in self.plugins:
self.plugin_options.setdefault(plugin_name, {})[key] = value # type: ignore[index]
self.plugin_options.setdefault(plugin_name, {})[key] = value # type: ignore[index]
return
# If we get here, we didn't find the option.
raise ConfigError(f"No such option: {option_name!r}")
raise ConfigError(f'No such option: {option_name!r}')
def get_option(self, option_name: str) -> TConfigValueOut | None:
"""Get an option from the configuration.
@ -495,7 +497,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
"""
# Special-cased options.
if option_name == "paths":
if option_name == 'paths':
return self.paths # type: ignore[return-value]
# Check all the hard-coded options.
@ -505,12 +507,12 @@ class CoverageConfig(TConfigurable, TPluginConfig):
return getattr(self, attr) # type: ignore[no-any-return]
# See if it's a plugin option.
plugin_name, _, key = option_name.partition(":")
plugin_name, _, key = option_name.partition(':')
if key and plugin_name in self.plugins:
return self.plugin_options.get(plugin_name, {}).get(key)
# If we get here, we didn't find the option.
raise ConfigError(f"No such option: {option_name!r}")
raise ConfigError(f'No such option: {option_name!r}')
def post_process_file(self, path: str) -> str:
"""Make final adjustments to a file path to make it usable."""
@ -530,7 +532,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
def debug_info(self) -> list[tuple[str, Any]]:
"""Make a list of (name, value) pairs for writing debug info."""
return human_sorted_items(
(k, v) for k, v in self.__dict__.items() if not k.startswith("_")
(k, v) for k, v in self.__dict__.items() if not k.startswith('_')
)
@ -543,24 +545,24 @@ def config_files_to_try(config_file: bool | str) -> list[tuple[str, bool, bool]]
# Some API users were specifying ".coveragerc" to mean the same as
# True, so make it so.
if config_file == ".coveragerc":
if config_file == '.coveragerc':
config_file = True
specified_file = (config_file is not True)
if not specified_file:
# No file was specified. Check COVERAGE_RCFILE.
rcfile = os.getenv("COVERAGE_RCFILE")
rcfile = os.getenv('COVERAGE_RCFILE')
if rcfile:
config_file = rcfile
specified_file = True
if not specified_file:
# Still no file specified. Default to .coveragerc
config_file = ".coveragerc"
config_file = '.coveragerc'
assert isinstance(config_file, str)
files_to_try = [
(config_file, True, specified_file),
("setup.cfg", False, False),
("tox.ini", False, False),
("pyproject.toml", False, False),
('setup.cfg', False, False),
('tox.ini', False, False),
('pyproject.toml', False, False),
]
return files_to_try
@ -601,13 +603,13 @@ def read_coverage_config(
raise ConfigError(f"Couldn't read {fname!r} as a config file")
# 3) from environment variables:
env_data_file = os.getenv("COVERAGE_FILE")
env_data_file = os.getenv('COVERAGE_FILE')
if env_data_file:
config.data_file = env_data_file
# $set_env.py: COVERAGE_DEBUG - Debug options: https://coverage.rtfd.io/cmd.html#debug
debugs = os.getenv("COVERAGE_DEBUG")
debugs = os.getenv('COVERAGE_DEBUG')
if debugs:
config.debug.extend(d.strip() for d in debugs.split(","))
config.debug.extend(d.strip() for d in debugs.split(','))
# 4) from constructor arguments:
config.from_args(**kwargs)

View file

@ -1,12 +1,12 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Determine contexts for coverage.py"""
from __future__ import annotations
from types import FrameType
from typing import cast, Callable, Sequence
from typing import Callable
from typing import cast
from typing import Sequence
def combine_context_switchers(
@ -44,7 +44,7 @@ def combine_context_switchers(
def should_start_context_test_function(frame: FrameType) -> str | None:
"""Is this frame calling a test_* function?"""
co_name = frame.f_code.co_name
if co_name.startswith("test") or co_name == "runTest":
if co_name.startswith('test') or co_name == 'runTest':
return qualname_from_frame(frame)
return None
@ -54,19 +54,19 @@ def qualname_from_frame(frame: FrameType) -> str | None:
co = frame.f_code
fname = co.co_name
method = None
if co.co_argcount and co.co_varnames[0] == "self":
self = frame.f_locals.get("self", None)
if co.co_argcount and co.co_varnames[0] == 'self':
self = frame.f_locals.get('self', None)
method = getattr(self, fname, None)
if method is None:
func = frame.f_globals.get(fname)
if func is None:
return None
return cast(str, func.__module__ + "." + fname)
return cast(str, func.__module__ + '.' + fname)
func = getattr(method, "__func__", None)
func = getattr(method, '__func__', None)
if func is None:
cls = self.__class__
return cast(str, cls.__module__ + "." + cls.__name__ + "." + fname)
return cast(str, cls.__module__ + '.' + cls.__name__ + '.' + fname)
return cast(str, func.__module__ + "." + func.__qualname__)
return cast(str, func.__module__ + '.' + func.__qualname__)

View file

@ -1,14 +1,11 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Central control stuff for coverage.py."""
from __future__ import annotations
import atexit
import collections
import contextlib
import os
import os.path
import platform
import signal
@ -16,31 +13,48 @@ import sys
import threading
import time
import warnings
from types import FrameType
from typing import (
cast,
Any, Callable, IO, Iterable, Iterator, List,
)
from typing import Any
from typing import Callable
from typing import cast
from typing import IO
from typing import Iterable
from typing import Iterator
from typing import List
from coverage import env
from coverage.annotate import AnnotateReporter
from coverage.collector import Collector, HAS_CTRACER
from coverage.config import CoverageConfig, read_coverage_config
from coverage.context import should_start_context_test_function, combine_context_switchers
from coverage.data import CoverageData, combine_parallel_data
from coverage.debug import (
DebugControl, NoDebugging, short_stack, write_formatted_info, relevant_environment_display,
)
from coverage.collector import Collector
from coverage.collector import HAS_CTRACER
from coverage.config import CoverageConfig
from coverage.config import read_coverage_config
from coverage.context import combine_context_switchers
from coverage.context import should_start_context_test_function
from coverage.data import combine_parallel_data
from coverage.data import CoverageData
from coverage.debug import DebugControl
from coverage.debug import NoDebugging
from coverage.debug import relevant_environment_display
from coverage.debug import short_stack
from coverage.debug import write_formatted_info
from coverage.disposition import disposition_debug_msg
from coverage.exceptions import ConfigError, CoverageException, CoverageWarning, PluginError
from coverage.files import PathAliases, abs_file, relative_filename, set_relative_directory
from coverage.exceptions import ConfigError
from coverage.exceptions import CoverageException
from coverage.exceptions import CoverageWarning
from coverage.exceptions import PluginError
from coverage.files import abs_file
from coverage.files import PathAliases
from coverage.files import relative_filename
from coverage.files import set_relative_directory
from coverage.html import HtmlReporter
from coverage.inorout import InOrOut
from coverage.jsonreport import JsonReporter
from coverage.lcovreport import LcovReporter
from coverage.misc import bool_or_none, join_regex
from coverage.misc import DefaultValue, ensure_dir_for_file, isolate_module
from coverage.misc import bool_or_none
from coverage.misc import DefaultValue
from coverage.misc import ensure_dir_for_file
from coverage.misc import isolate_module
from coverage.misc import join_regex
from coverage.multiproc import patch_multiprocessing
from coverage.plugin import FileReporter
from coverage.plugin_support import Plugins
@ -48,14 +62,19 @@ from coverage.python import PythonFileReporter
from coverage.report import SummaryReporter
from coverage.report_core import render_report
from coverage.results import Analysis
from coverage.types import (
FilePath, TConfigurable, TConfigSectionIn, TConfigValueIn, TConfigValueOut,
TFileDisposition, TLineNo, TMorf,
)
from coverage.types import FilePath
from coverage.types import TConfigSectionIn
from coverage.types import TConfigurable
from coverage.types import TConfigValueIn
from coverage.types import TConfigValueOut
from coverage.types import TFileDisposition
from coverage.types import TLineNo
from coverage.types import TMorf
from coverage.xmlreport import XmlReporter
os = isolate_module(os)
@contextlib.contextmanager
def override_config(cov: Coverage, **kwargs: TConfigValueIn) -> Iterator[None]:
"""Temporarily tweak the configuration of `cov`.
@ -72,9 +91,10 @@ def override_config(cov: Coverage, **kwargs: TConfigValueIn) -> Iterator[None]:
cov.config = original_config
DEFAULT_DATAFILE = DefaultValue("MISSING")
DEFAULT_DATAFILE = DefaultValue('MISSING')
_DEFAULT_DATAFILE = DEFAULT_DATAFILE # Just in case, for backwards compatibility
class Coverage(TConfigurable):
"""Programmatic access to coverage.py.
@ -323,10 +343,10 @@ class Coverage(TConfigurable):
# Create and configure the debugging controller.
self._debug = DebugControl(self.config.debug, self._debug_file, self.config.debug_file)
if self._debug.should("process"):
self._debug.write("Coverage._init")
if self._debug.should('process'):
self._debug.write('Coverage._init')
if "multiprocessing" in (self.config.concurrency or ()):
if 'multiprocessing' in (self.config.concurrency or ()):
# Multi-processing uses parallel for the subprocesses, so also use
# it for the main process.
self.config.parallel = True
@ -358,31 +378,31 @@ class Coverage(TConfigurable):
# "[run] _crash" will raise an exception if the value is close by in
# the call stack, for testing error handling.
if self.config._crash and self.config._crash in short_stack():
raise RuntimeError(f"Crashing because called by {self.config._crash}")
raise RuntimeError(f'Crashing because called by {self.config._crash}')
def _write_startup_debug(self) -> None:
"""Write out debug info at startup if needed."""
wrote_any = False
with self._debug.without_callers():
if self._debug.should("config"):
if self._debug.should('config'):
config_info = self.config.debug_info()
write_formatted_info(self._debug.write, "config", config_info)
write_formatted_info(self._debug.write, 'config', config_info)
wrote_any = True
if self._debug.should("sys"):
write_formatted_info(self._debug.write, "sys", self.sys_info())
if self._debug.should('sys'):
write_formatted_info(self._debug.write, 'sys', self.sys_info())
for plugin in self._plugins:
header = "sys: " + plugin._coverage_plugin_name
header = 'sys: ' + plugin._coverage_plugin_name
info = plugin.sys_info()
write_formatted_info(self._debug.write, header, info)
wrote_any = True
if self._debug.should("pybehave"):
write_formatted_info(self._debug.write, "pybehave", env.debug_info())
if self._debug.should('pybehave'):
write_formatted_info(self._debug.write, 'pybehave', env.debug_info())
wrote_any = True
if wrote_any:
write_formatted_info(self._debug.write, "end", ())
write_formatted_info(self._debug.write, 'end', ())
def _should_trace(self, filename: str, frame: FrameType) -> TFileDisposition:
"""Decide whether to trace execution in `filename`.
@ -392,7 +412,7 @@ class Coverage(TConfigurable):
"""
assert self._inorout is not None
disp = self._inorout.should_trace(filename, frame)
if self._debug.should("trace"):
if self._debug.should('trace'):
self._debug.write(disposition_debug_msg(disp))
return disp
@ -404,11 +424,11 @@ class Coverage(TConfigurable):
"""
assert self._inorout is not None
reason = self._inorout.check_include_omit_etc(filename, frame)
if self._debug.should("trace"):
if self._debug.should('trace'):
if not reason:
msg = f"Including {filename!r}"
msg = f'Including {filename!r}'
else:
msg = f"Not including {filename!r}: {reason}"
msg = f'Not including {filename!r}: {reason}'
self._debug.write(msg)
return not reason
@ -431,9 +451,9 @@ class Coverage(TConfigurable):
self._warnings.append(msg)
if slug:
msg = f"{msg} ({slug})"
if self._debug.should("pid"):
msg = f"[{os.getpid()}] {msg}"
msg = f'{msg} ({slug})'
if self._debug.should('pid'):
msg = f'[{os.getpid()}] {msg}'
warnings.warn(msg, category=CoverageWarning, stacklevel=2)
if once:
@ -512,15 +532,15 @@ class Coverage(TConfigurable):
"""Initialization for start()"""
# Construct the collector.
concurrency: list[str] = self.config.concurrency or []
if "multiprocessing" in concurrency:
if 'multiprocessing' in concurrency:
if self.config.config_file is None:
raise ConfigError("multiprocessing requires a configuration file")
raise ConfigError('multiprocessing requires a configuration file')
patch_multiprocessing(rcfile=self.config.config_file)
dycon = self.config.dynamic_context
if not dycon or dycon == "none":
if not dycon or dycon == 'none':
context_switchers = []
elif dycon == "test_function":
elif dycon == 'test_function':
context_switchers = [should_start_context_test_function]
else:
raise ConfigError(f"Don't understand dynamic_context setting: {dycon!r}")
@ -565,9 +585,9 @@ class Coverage(TConfigurable):
if self._plugins.file_tracers and not self._collector.supports_plugins:
self._warn(
"Plugin file tracers ({}) aren't supported with {}".format(
", ".join(
', '.join(
plugin._coverage_plugin_name
for plugin in self._plugins.file_tracers
for plugin in self._plugins.file_tracers
),
self._collector.tracer_name(),
),
@ -579,7 +599,7 @@ class Coverage(TConfigurable):
self._inorout = InOrOut(
config=self.config,
warn=self._warn,
debug=(self._debug if self._debug.should("trace") else None),
debug=(self._debug if self._debug.should('trace') else None),
include_namespace_packages=self.config.include_namespace_packages,
)
self._inorout.plugins = self._plugins
@ -676,18 +696,18 @@ class Coverage(TConfigurable):
finally:
self.stop()
def _atexit(self, event: str = "atexit") -> None:
def _atexit(self, event: str = 'atexit') -> None:
"""Clean up on process shutdown."""
if self._debug.should("process"):
self._debug.write(f"{event}: pid: {os.getpid()}, instance: {self!r}")
if self._debug.should('process'):
self._debug.write(f'{event}: pid: {os.getpid()}, instance: {self!r}')
if self._started:
self.stop()
if self._auto_save or event == "sigterm":
if self._auto_save or event == 'sigterm':
self.save()
def _on_sigterm(self, signum_unused: int, frame_unused: FrameType | None) -> None:
"""A handler for signal.SIGTERM."""
self._atexit("sigterm")
self._atexit('sigterm')
# Statements after here won't be seen by metacov because we just wrote
# the data, and are about to kill the process.
signal.signal(signal.SIGTERM, self._old_sigterm) # pragma: not covered
@ -724,21 +744,21 @@ class Coverage(TConfigurable):
"""
if not self._started: # pragma: part started
raise CoverageException("Cannot switch context, coverage is not started")
raise CoverageException('Cannot switch context, coverage is not started')
assert self._collector is not None
if self._collector.should_start_context:
self._warn("Conflicting dynamic contexts", slug="dynamic-conflict", once=True)
self._warn('Conflicting dynamic contexts', slug='dynamic-conflict', once=True)
self._collector.switch_context(new_context)
def clear_exclude(self, which: str = "exclude") -> None:
def clear_exclude(self, which: str = 'exclude') -> None:
"""Clear the exclude list."""
self._init()
setattr(self.config, which + "_list", [])
setattr(self.config, which + '_list', [])
self._exclude_regex_stale()
def exclude(self, regex: str, which: str = "exclude") -> None:
def exclude(self, regex: str, which: str = 'exclude') -> None:
"""Exclude source lines from execution consideration.
A number of lists of regular expressions are maintained. Each list
@ -754,7 +774,7 @@ class Coverage(TConfigurable):
"""
self._init()
excl_list = getattr(self.config, which + "_list")
excl_list = getattr(self.config, which + '_list')
excl_list.append(regex)
self._exclude_regex_stale()
@ -765,11 +785,11 @@ class Coverage(TConfigurable):
def _exclude_regex(self, which: str) -> str:
"""Return a regex string for the given exclusion list."""
if which not in self._exclude_re:
excl_list = getattr(self.config, which + "_list")
excl_list = getattr(self.config, which + '_list')
self._exclude_re[which] = join_regex(excl_list)
return self._exclude_re[which]
def get_exclude_list(self, which: str = "exclude") -> list[str]:
def get_exclude_list(self, which: str = 'exclude') -> list[str]:
"""Return a list of excluded regex strings.
`which` indicates which list is desired. See :meth:`exclude` for the
@ -777,7 +797,7 @@ class Coverage(TConfigurable):
"""
self._init()
return cast(List[str], getattr(self.config, which + "_list"))
return cast(List[str], getattr(self.config, which + '_list'))
def save(self) -> None:
"""Save the collected coverage data to the data file."""
@ -787,7 +807,7 @@ class Coverage(TConfigurable):
def _make_aliases(self) -> PathAliases:
"""Create a PathAliases from our configuration."""
aliases = PathAliases(
debugfn=(self._debug.write if self._debug.should("pathmap") else None),
debugfn=(self._debug.write if self._debug.should('pathmap') else None),
relative=self.config.relative_files,
)
for paths in self.config.paths.values():
@ -884,7 +904,7 @@ class Coverage(TConfigurable):
# Find out if we got any data.
if not self._data and self._warn_no_data:
self._warn("No data was collected.", slug="no-data-collected")
self._warn('No data was collected.', slug='no-data-collected')
# Touch all the files that could have executed, so that we can
# mark completely un-executed files as 0% covered.
@ -952,7 +972,7 @@ class Coverage(TConfigurable):
"""Get a FileReporter for a module or file name."""
assert self._data is not None
plugin = None
file_reporter: str | FileReporter = "python"
file_reporter: str | FileReporter = 'python'
if isinstance(morf, str):
mapped_morf = self._file_mapper(morf)
@ -964,12 +984,12 @@ class Coverage(TConfigurable):
file_reporter = plugin.file_reporter(mapped_morf)
if file_reporter is None:
raise PluginError(
"Plugin {!r} did not provide a file reporter for {!r}.".format(
'Plugin {!r} did not provide a file reporter for {!r}.'.format(
plugin._coverage_plugin_name, morf,
),
)
if file_reporter == "python":
if file_reporter == 'python':
file_reporter = PythonFileReporter(morf, self)
assert isinstance(file_reporter, FileReporter)
@ -1290,36 +1310,37 @@ class Coverage(TConfigurable):
for plugin in plugins:
entry = plugin._coverage_plugin_name
if not plugin._coverage_enabled:
entry += " (disabled)"
entry += ' (disabled)'
entries.append(entry)
return entries
info = [
("coverage_version", covmod.__version__),
("coverage_module", covmod.__file__),
("core", self._collector.tracer_name() if self._collector is not None else "-none-"),
("CTracer", "available" if HAS_CTRACER else "unavailable"),
("plugins.file_tracers", plugin_info(self._plugins.file_tracers)),
("plugins.configurers", plugin_info(self._plugins.configurers)),
("plugins.context_switchers", plugin_info(self._plugins.context_switchers)),
("configs_attempted", self.config.attempted_config_files),
("configs_read", self.config.config_files_read),
("config_file", self.config.config_file),
("config_contents",
repr(self.config._config_contents) if self.config._config_contents else "-none-",
('coverage_version', covmod.__version__),
('coverage_module', covmod.__file__),
('core', self._collector.tracer_name() if self._collector is not None else '-none-'),
('CTracer', 'available' if HAS_CTRACER else 'unavailable'),
('plugins.file_tracers', plugin_info(self._plugins.file_tracers)),
('plugins.configurers', plugin_info(self._plugins.configurers)),
('plugins.context_switchers', plugin_info(self._plugins.context_switchers)),
('configs_attempted', self.config.attempted_config_files),
('configs_read', self.config.config_files_read),
('config_file', self.config.config_file),
(
'config_contents',
repr(self.config._config_contents) if self.config._config_contents else '-none-',
),
("data_file", self._data.data_filename() if self._data is not None else "-none-"),
("python", sys.version.replace("\n", "")),
("platform", platform.platform()),
("implementation", platform.python_implementation()),
("executable", sys.executable),
("def_encoding", sys.getdefaultencoding()),
("fs_encoding", sys.getfilesystemencoding()),
("pid", os.getpid()),
("cwd", os.getcwd()),
("path", sys.path),
("environment", [f"{k} = {v}" for k, v in relevant_environment_display(os.environ)]),
("command_line", " ".join(getattr(sys, "argv", ["-none-"]))),
('data_file', self._data.data_filename() if self._data is not None else '-none-'),
('python', sys.version.replace('\n', '')),
('platform', platform.platform()),
('implementation', platform.python_implementation()),
('executable', sys.executable),
('def_encoding', sys.getdefaultencoding()),
('fs_encoding', sys.getfilesystemencoding()),
('pid', os.getpid()),
('cwd', os.getcwd()),
('path', sys.path),
('environment', [f'{k} = {v}' for k, v in relevant_environment_display(os.environ)]),
('command_line', ' '.join(getattr(sys, 'argv', ['-none-']))),
]
if self._inorout is not None:
@ -1332,12 +1353,12 @@ class Coverage(TConfigurable):
# Mega debugging...
# $set_env.py: COVERAGE_DEBUG_CALLS - Lots and lots of output about calls to Coverage.
if int(os.getenv("COVERAGE_DEBUG_CALLS", 0)): # pragma: debugging
if int(os.getenv('COVERAGE_DEBUG_CALLS', 0)): # pragma: debugging
from coverage.debug import decorate_methods, show_calls
Coverage = decorate_methods( # type: ignore[misc]
show_calls(show_args=True),
butnot=["get_data"],
butnot=['get_data'],
)(Coverage)
@ -1364,7 +1385,7 @@ def process_startup() -> Coverage | None:
not started by this call.
"""
cps = os.getenv("COVERAGE_PROCESS_START")
cps = os.getenv('COVERAGE_PROCESS_START')
if not cps:
# No request for coverage, nothing to do.
return None
@ -1378,7 +1399,7 @@ def process_startup() -> Coverage | None:
#
# https://github.com/nedbat/coveragepy/issues/340 has more details.
if hasattr(process_startup, "coverage"):
if hasattr(process_startup, 'coverage'):
# We've annotated this function before, so we must have already
# started coverage.py in this process. Nothing to do.
return None
@ -1396,6 +1417,6 @@ def process_startup() -> Coverage | None:
def _prevent_sub_process_measurement() -> None:
"""Stop any subprocess auto-measurement from writing data."""
auto_created_coverage = getattr(process_startup, "coverage", None)
auto_created_coverage = getattr(process_startup, 'coverage', None)
if auto_created_coverage is not None:
auto_created_coverage._auto_save = False

View file

@ -1,6 +1,5 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Coverage data for coverage.py.
This file had the 4.x JSON data support, which is now gone. This file still
@ -9,18 +8,21 @@ CoverageData is now defined in sqldata.py, and imported here to keep the
imports working.
"""
from __future__ import annotations
import glob
import hashlib
import os.path
from typing import Callable
from typing import Iterable
from typing import Callable, Iterable
from coverage.exceptions import CoverageException, NoDataError
from coverage.exceptions import CoverageException
from coverage.exceptions import NoDataError
from coverage.files import PathAliases
from coverage.misc import Hasher, file_be_gone, human_sorted, plural
from coverage.misc import file_be_gone
from coverage.misc import Hasher
from coverage.misc import human_sorted
from coverage.misc import plural
from coverage.sqldata import CoverageData
@ -38,7 +40,7 @@ def line_counts(data: CoverageData, fullpath: bool = False) -> dict[str, int]:
filename_fn: Callable[[str], str]
if fullpath:
# pylint: disable=unnecessary-lambda-assignment
filename_fn = lambda f: f
def filename_fn(f): return f
else:
filename_fn = os.path.basename
for filename in data.measured_files():
@ -79,14 +81,14 @@ def combinable_files(data_file: str, data_paths: Iterable[str] | None = None) ->
if os.path.isfile(p):
files_to_combine.append(os.path.abspath(p))
elif os.path.isdir(p):
pattern = glob.escape(os.path.join(os.path.abspath(p), local)) +".*"
pattern = glob.escape(os.path.join(os.path.abspath(p), local)) + '.*'
files_to_combine.extend(glob.glob(pattern))
else:
raise NoDataError(f"Couldn't combine from non-existent path '{p}'")
# SQLite might have made journal files alongside our database files.
# We never want to combine those.
files_to_combine = [fnm for fnm in files_to_combine if not fnm.endswith("-journal")]
files_to_combine = [fnm for fnm in files_to_combine if not fnm.endswith('-journal')]
# Sorting isn't usually needed, since it shouldn't matter what order files
# are combined, but sorting makes tests more predictable, and makes
@ -132,7 +134,7 @@ def combine_parallel_data(
files_to_combine = combinable_files(data.base_filename(), data_paths)
if strict and not files_to_combine:
raise NoDataError("No data to combine")
raise NoDataError('No data to combine')
file_hashes = set()
combined_any = False
@ -141,8 +143,8 @@ def combine_parallel_data(
if f == data.data_filename():
# Sometimes we are combining into a file which is one of the
# parallel files. Skip that file.
if data._debug.should("dataio"):
data._debug.write(f"Skipping combining ourself: {f!r}")
if data._debug.should('dataio'):
data._debug.write(f'Skipping combining ourself: {f!r}')
continue
try:
@ -153,16 +155,16 @@ def combine_parallel_data(
# we print the original value of f instead of its relative path
rel_file_name = f
with open(f, "rb") as fobj:
hasher = hashlib.new("sha3_256")
with open(f, 'rb') as fobj:
hasher = hashlib.new('sha3_256')
hasher.update(fobj.read())
sha = hasher.digest()
combine_this_one = sha not in file_hashes
delete_this_one = not keep
if combine_this_one:
if data._debug.should("dataio"):
data._debug.write(f"Combining data file {f!r}")
if data._debug.should('dataio'):
data._debug.write(f'Combining data file {f!r}')
file_hashes.add(sha)
try:
new_data = CoverageData(f, debug=data._debug)
@ -179,39 +181,39 @@ def combine_parallel_data(
data.update(new_data, aliases=aliases)
combined_any = True
if message:
message(f"Combined data file {rel_file_name}")
message(f'Combined data file {rel_file_name}')
else:
if message:
message(f"Skipping duplicate data {rel_file_name}")
message(f'Skipping duplicate data {rel_file_name}')
if delete_this_one:
if data._debug.should("dataio"):
data._debug.write(f"Deleting data file {f!r}")
if data._debug.should('dataio'):
data._debug.write(f'Deleting data file {f!r}')
file_be_gone(f)
if strict and not combined_any:
raise NoDataError("No usable data files")
raise NoDataError('No usable data files')
def debug_data_file(filename: str) -> None:
"""Implementation of 'coverage debug data'."""
data = CoverageData(filename)
filename = data.data_filename()
print(f"path: {filename}")
print(f'path: {filename}')
if not os.path.exists(filename):
print("No data collected: file doesn't exist")
return
data.read()
print(f"has_arcs: {data.has_arcs()!r}")
print(f'has_arcs: {data.has_arcs()!r}')
summary = line_counts(data, fullpath=True)
filenames = human_sorted(summary.keys())
nfiles = len(filenames)
print(f"{nfiles} file{plural(nfiles)}:")
print(f'{nfiles} file{plural(nfiles)}:')
for f in filenames:
line = f"{f}: {summary[f]} line{plural(summary[f])}"
line = f'{f}: {summary[f]} line{plural(summary[f])}'
plugin = data.file_tracer(f)
if plugin:
line += f" [{plugin}]"
line += f' [{plugin}]'
print(line)

View file

@ -1,10 +1,9 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Control of and utilities for debugging."""
from __future__ import annotations
import _thread
import atexit
import contextlib
import functools
@ -17,15 +16,18 @@ import reprlib
import sys
import traceback
import types
import _thread
from typing import Any
from typing import Callable
from typing import IO
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import overload
from typing import (
overload,
Any, Callable, IO, Iterable, Iterator, Mapping,
)
from coverage.misc import human_sorted_items, isolate_module
from coverage.types import AnyCallable, TWritable
from coverage.misc import human_sorted_items
from coverage.misc import isolate_module
from coverage.types import AnyCallable
from coverage.types import TWritable
os = isolate_module(os)
@ -53,12 +55,12 @@ class DebugControl:
self.suppress_callers = False
filters = []
if self.should("process"):
if self.should('process'):
filters.append(CwdTracker().filter)
filters.append(ProcessTracker().filter)
if self.should("pytest"):
if self.should('pytest'):
filters.append(PytestTracker().filter)
if self.should("pid"):
if self.should('pid'):
filters.append(add_pid_and_tid)
self.output = DebugOutputFile.get_one(
@ -69,11 +71,11 @@ class DebugControl:
self.raw_output = self.output.outfile
def __repr__(self) -> str:
return f"<DebugControl options={self.options!r} raw_output={self.raw_output!r}>"
return f'<DebugControl options={self.options!r} raw_output={self.raw_output!r}>'
def should(self, option: str) -> bool:
"""Decide whether to output debug information in category `option`."""
if option == "callers" and self.suppress_callers:
if option == 'callers' and self.suppress_callers:
return False
return (option in self.options)
@ -96,20 +98,21 @@ class DebugControl:
after the message.
"""
self.output.write(msg + "\n")
self.output.write(msg + '\n')
if exc is not None:
self.output.write("".join(traceback.format_exception(None, exc, exc.__traceback__)))
if self.should("self"):
caller_self = inspect.stack()[1][0].f_locals.get("self")
self.output.write(''.join(traceback.format_exception(None, exc, exc.__traceback__)))
if self.should('self'):
caller_self = inspect.stack()[1][0].f_locals.get('self')
if caller_self is not None:
self.output.write(f"self: {caller_self!r}\n")
if self.should("callers"):
self.output.write(f'self: {caller_self!r}\n')
if self.should('callers'):
dump_stack_frames(out=self.output, skip=1)
self.output.flush()
class NoDebugging(DebugControl):
"""A replacement for DebugControl that will never try to do anything."""
def __init__(self) -> None:
# pylint: disable=super-init-not-called
...
@ -120,12 +123,12 @@ class NoDebugging(DebugControl):
def write(self, msg: str, *, exc: BaseException | None = None) -> None:
"""This will never be called."""
raise AssertionError("NoDebugging.write should never be called.")
raise AssertionError('NoDebugging.write should never be called.')
def info_header(label: str) -> str:
"""Make a nice header string."""
return "--{:-<60s}".format(" "+label+" ")
return '--{:-<60s}'.format(' ' + label + ' ')
def info_formatter(info: Iterable[tuple[str, Any]]) -> Iterator[str]:
@ -142,17 +145,17 @@ def info_formatter(info: Iterable[tuple[str, Any]]) -> Iterator[str]:
assert all(len(l) < label_len for l, _ in info)
for label, data in info:
if data == []:
data = "-none-"
data = '-none-'
if isinstance(data, tuple) and len(repr(tuple(data))) < 30:
# Convert to tuple to scrub namedtuples.
yield "%*s: %r" % (label_len, label, tuple(data))
yield '%*s: %r' % (label_len, label, tuple(data))
elif isinstance(data, (list, set, tuple)):
prefix = "%*s:" % (label_len, label)
prefix = '%*s:' % (label_len, label)
for e in data:
yield "%*s %s" % (label_len+1, prefix, e)
prefix = ""
yield '%*s %s' % (label_len + 1, prefix, e)
prefix = ''
else:
yield "%*s: %s" % (label_len, label, data)
yield '%*s: %s' % (label_len, label, data)
def write_formatted_info(
@ -170,35 +173,38 @@ def write_formatted_info(
"""
write(info_header(header))
for line in info_formatter(info):
write(f" {line}")
write(f' {line}')
def exc_one_line(exc: Exception) -> str:
"""Get a one-line summary of an exception, including class name and message."""
lines = traceback.format_exception_only(type(exc), exc)
return "|".join(l.rstrip() for l in lines)
return '|'.join(l.rstrip() for l in lines)
_FILENAME_REGEXES: list[tuple[str, str]] = [
(r".*[/\\]pytest-of-.*[/\\]pytest-\d+([/\\]popen-gw\d+)?", "tmp:"),
(r'.*[/\\]pytest-of-.*[/\\]pytest-\d+([/\\]popen-gw\d+)?', 'tmp:'),
]
_FILENAME_SUBS: list[tuple[str, str]] = []
@overload
def short_filename(filename: str) -> str:
pass
@overload
def short_filename(filename: None) -> None:
pass
def short_filename(filename: str | None) -> str | None:
"""Shorten a file name. Directories are replaced by prefixes like 'syspath:'"""
if not _FILENAME_SUBS:
for pathdir in sys.path:
_FILENAME_SUBS.append((pathdir, "syspath:"))
_FILENAME_SUBS.append((pathdir, 'syspath:'))
import coverage
_FILENAME_SUBS.append((os.path.dirname(coverage.__file__), "cov:"))
_FILENAME_SUBS.append((os.path.dirname(coverage.__file__), 'cov:'))
_FILENAME_SUBS.sort(key=(lambda pair: len(pair[0])), reverse=True)
if filename is not None:
for pat, sub in _FILENAME_REGEXES:
@ -237,9 +243,9 @@ def short_stack(
"""
# Regexes in initial frames that we don't care about.
BORING_PRELUDE = [
"<string>", # pytest-xdist has string execution.
r"\bigor.py$", # Our test runner.
r"\bsite-packages\b", # pytest etc getting to our tests.
'<string>', # pytest-xdist has string execution.
r'\bigor.py$', # Our test runner.
r'\bsite-packages\b', # pytest etc getting to our tests.
]
stack: Iterable[inspect.FrameInfo] = inspect.stack()[:skip:-1]
@ -251,20 +257,20 @@ def short_stack(
)
lines = []
for frame_info in stack:
line = f"{frame_info.function:>30s} : "
line = f'{frame_info.function:>30s} : '
if frame_ids:
line += f"{id(frame_info.frame):#x} "
line += f'{id(frame_info.frame):#x} '
filename = frame_info.filename
if short_filenames:
filename = short_filename(filename)
line += f"{filename}:{frame_info.lineno}"
line += f'{filename}:{frame_info.lineno}'
lines.append(line)
return "\n".join(lines)
return '\n'.join(lines)
def dump_stack_frames(out: TWritable, skip: int = 0) -> None:
"""Print a summary of the stack to `out`."""
out.write(short_stack(skip=skip+1) + "\n")
out.write(short_stack(skip=skip + 1) + '\n')
def clipped_repr(text: str, numchars: int = 50) -> str:
@ -285,36 +291,37 @@ def short_id(id64: int) -> int:
def add_pid_and_tid(text: str) -> str:
"""A filter to add pid and tid to debug messages."""
# Thread ids are useful, but too long. Make a shorter one.
tid = f"{short_id(_thread.get_ident()):04x}"
text = f"{os.getpid():5d}.{tid}: {text}"
tid = f'{short_id(_thread.get_ident()):04x}'
text = f'{os.getpid():5d}.{tid}: {text}'
return text
AUTO_REPR_IGNORE = {"$coverage.object_id"}
AUTO_REPR_IGNORE = {'$coverage.object_id'}
def auto_repr(self: Any) -> str:
"""A function implementing an automatic __repr__ for debugging."""
show_attrs = (
(k, v) for k, v in self.__dict__.items()
if getattr(v, "show_repr_attr", True)
and not inspect.ismethod(v)
and k not in AUTO_REPR_IGNORE
if getattr(v, 'show_repr_attr', True) and
not inspect.ismethod(v) and
k not in AUTO_REPR_IGNORE
)
return "<{klass} @{id:#x}{attrs}>".format(
return '<{klass} @{id:#x}{attrs}>'.format(
klass=self.__class__.__name__,
id=id(self),
attrs="".join(f" {k}={v!r}" for k, v in show_attrs),
attrs=''.join(f' {k}={v!r}' for k, v in show_attrs),
)
def simplify(v: Any) -> Any: # pragma: debugging
"""Turn things which are nearly dict/list/etc into dict/list/etc."""
if isinstance(v, dict):
return {k:simplify(vv) for k, vv in v.items()}
return {k: simplify(vv) for k, vv in v.items()}
elif isinstance(v, (list, tuple)):
return type(v)(simplify(vv) for vv in v)
elif hasattr(v, "__dict__"):
return simplify({"."+k: v for k, v in v.__dict__.items()})
elif hasattr(v, '__dict__'):
return simplify({'.' + k: v for k, v in v.__dict__.items()})
else:
return v
@ -343,12 +350,13 @@ def filter_text(text: str, filters: Iterable[Callable[[str], str]]) -> str:
lines = []
for line in text.splitlines():
lines.extend(filter_fn(line).splitlines())
text = "\n".join(lines)
text = '\n'.join(lines)
return text + ending
class CwdTracker:
"""A class to add cwd info to debug messages."""
def __init__(self) -> None:
self.cwd: str | None = None
@ -356,32 +364,33 @@ class CwdTracker:
"""Add a cwd message for each new cwd."""
cwd = os.getcwd()
if cwd != self.cwd:
text = f"cwd is now {cwd!r}\n" + text
text = f'cwd is now {cwd!r}\n' + text
self.cwd = cwd
return text
class ProcessTracker:
"""Track process creation for debug logging."""
def __init__(self) -> None:
self.pid: int = os.getpid()
self.did_welcome = False
def filter(self, text: str) -> str:
"""Add a message about how new processes came to be."""
welcome = ""
welcome = ''
pid = os.getpid()
if self.pid != pid:
welcome = f"New process: forked {self.pid} -> {pid}\n"
welcome = f'New process: forked {self.pid} -> {pid}\n'
self.pid = pid
elif not self.did_welcome:
argv = getattr(sys, "argv", None)
argv = getattr(sys, 'argv', None)
welcome = (
f"New process: {pid=}, executable: {sys.executable!r}\n"
+ f"New process: cmd: {argv!r}\n"
f'New process: {pid=}, executable: {sys.executable!r}\n' +
f'New process: cmd: {argv!r}\n'
)
if hasattr(os, "getppid"):
welcome += f"New process parent pid: {os.getppid()!r}\n"
if hasattr(os, 'getppid'):
welcome += f'New process parent pid: {os.getppid()!r}\n'
if welcome:
self.did_welcome = True
@ -392,20 +401,22 @@ class ProcessTracker:
class PytestTracker:
"""Track the current pytest test name to add to debug messages."""
def __init__(self) -> None:
self.test_name: str | None = None
def filter(self, text: str) -> str:
"""Add a message when the pytest test changes."""
test_name = os.getenv("PYTEST_CURRENT_TEST")
test_name = os.getenv('PYTEST_CURRENT_TEST')
if test_name != self.test_name:
text = f"Pytest context: {test_name}\n" + text
text = f'Pytest context: {test_name}\n' + text
self.test_name = test_name
return text
class DebugOutputFile:
"""A file-like object that includes pid and cwd information."""
def __init__(
self,
outfile: IO[str] | None,
@ -444,21 +455,21 @@ class DebugOutputFile:
the_one, is_interim = cls._get_singleton_data()
if the_one is None or is_interim:
if file_name is not None:
fileobj = open(file_name, "a", encoding="utf-8")
fileobj = open(file_name, 'a', encoding='utf-8')
else:
# $set_env.py: COVERAGE_DEBUG_FILE - Where to write debug output
file_name = os.getenv("COVERAGE_DEBUG_FILE", FORCED_DEBUG_FILE)
if file_name in ("stdout", "stderr"):
file_name = os.getenv('COVERAGE_DEBUG_FILE', FORCED_DEBUG_FILE)
if file_name in ('stdout', 'stderr'):
fileobj = getattr(sys, file_name)
elif file_name:
fileobj = open(file_name, "a", encoding="utf-8")
fileobj = open(file_name, 'a', encoding='utf-8')
atexit.register(fileobj.close)
else:
fileobj = sys.stderr
the_one = cls(fileobj, filters)
cls._set_singleton_data(the_one, interim)
if not(the_one.filters):
if not (the_one.filters):
the_one.filters = list(filters)
return the_one
@ -467,8 +478,8 @@ class DebugOutputFile:
# a process-wide singleton. So stash it in sys.modules instead of
# on a class attribute. Yes, this is aggressively gross.
SYS_MOD_NAME = "$coverage.debug.DebugOutputFile.the_one"
SINGLETON_ATTR = "the_one_and_is_interim"
SYS_MOD_NAME = '$coverage.debug.DebugOutputFile.the_one'
SINGLETON_ATTR = 'the_one_and_is_interim'
@classmethod
def _set_singleton_data(cls, the_one: DebugOutputFile, interim: bool) -> None:
@ -504,7 +515,7 @@ class DebugOutputFile:
def log(msg: str, stack: bool = False) -> None: # pragma: debugging
"""Write a log message as forcefully as possible."""
out = DebugOutputFile.get_one(interim=True)
out.write(msg+"\n")
out.write(msg + '\n')
if stack:
dump_stack_frames(out=out, skip=1)
@ -519,8 +530,8 @@ def decorate_methods(
for name, meth in inspect.getmembers(cls, inspect.isroutine):
if name not in cls.__dict__:
continue
if name != "__init__":
if not private and name.startswith("_"):
if name != '__init__':
if not private and name.startswith('_'):
continue
if name in butnot:
continue
@ -542,7 +553,8 @@ def break_in_pudb(func: AnyCallable) -> AnyCallable: # pragma: debugging
OBJ_IDS = itertools.count()
CALLS = itertools.count()
OBJ_ID_ATTR = "$coverage.object_id"
OBJ_ID_ATTR = '$coverage.object_id'
def show_calls(
show_args: bool = True,
@ -555,27 +567,27 @@ def show_calls(
def _wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
oid = getattr(self, OBJ_ID_ATTR, None)
if oid is None:
oid = f"{os.getpid():08d} {next(OBJ_IDS):04d}"
oid = f'{os.getpid():08d} {next(OBJ_IDS):04d}'
setattr(self, OBJ_ID_ATTR, oid)
extra = ""
extra = ''
if show_args:
eargs = ", ".join(map(repr, args))
ekwargs = ", ".join("{}={!r}".format(*item) for item in kwargs.items())
extra += "("
eargs = ', '.join(map(repr, args))
ekwargs = ', '.join('{}={!r}'.format(*item) for item in kwargs.items())
extra += '('
extra += eargs
if eargs and ekwargs:
extra += ", "
extra += ', '
extra += ekwargs
extra += ")"
extra += ')'
if show_stack:
extra += " @ "
extra += "; ".join(short_stack(short_filenames=True).splitlines())
extra += ' @ '
extra += '; '.join(short_stack(short_filenames=True).splitlines())
callid = next(CALLS)
msg = f"{oid} {callid:04d} {func.__name__}{extra}\n"
msg = f'{oid} {callid:04d} {func.__name__}{extra}\n'
DebugOutputFile.get_one(interim=True).write(msg)
ret = func(self, *args, **kwargs)
if show_return:
msg = f"{oid} {callid:04d} {func.__name__} return {ret!r}\n"
msg = f'{oid} {callid:04d} {func.__name__} return {ret!r}\n'
DebugOutputFile.get_one(interim=True).write(msg)
return ret
return _wrapper
@ -595,9 +607,9 @@ def relevant_environment_display(env: Mapping[str, str]) -> list[tuple[str, str]
A list of pairs (name, value) to show.
"""
slugs = {"COV", "PY"}
include = {"HOME", "TEMP", "TMP"}
cloak = {"API", "TOKEN", "KEY", "SECRET", "PASS", "SIGNATURE"}
slugs = {'COV', 'PY'}
include = {'HOME', 'TEMP', 'TMP'}
cloak = {'API', 'TOKEN', 'KEY', 'SECRET', 'PASS', 'SIGNATURE'}
to_show = []
for name, val in env.items():
@ -608,6 +620,6 @@ def relevant_environment_display(env: Mapping[str, str]) -> list[tuple[str, str]
keep = True
if keep:
if any(slug in name for slug in cloak):
val = re.sub(r"\w", "*", val)
val = re.sub(r'\w', '*', val)
to_show.append((name, val))
return human_sorted_items(to_show)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Simple value objects for tracking what to do with files."""
from __future__ import annotations
from typing import TYPE_CHECKING
@ -25,7 +23,7 @@ class FileDisposition:
has_dynamic_filename: bool
def __repr__(self) -> str:
return f"<FileDisposition {self.canonical_filename!r}: trace={self.trace}>"
return f'<FileDisposition {self.canonical_filename!r}: trace={self.trace}>'
# FileDisposition "methods": FileDisposition is a pure value object, so it can
@ -39,7 +37,7 @@ def disposition_init(cls: type[TFileDisposition], original_filename: str) -> TFi
disp.canonical_filename = original_filename
disp.source_filename = None
disp.trace = False
disp.reason = ""
disp.reason = ''
disp.file_tracer = None
disp.has_dynamic_filename = False
return disp
@ -48,11 +46,11 @@ def disposition_init(cls: type[TFileDisposition], original_filename: str) -> TFi
def disposition_debug_msg(disp: TFileDisposition) -> str:
"""Make a nice debug message of what the FileDisposition is doing."""
if disp.trace:
msg = f"Tracing {disp.original_filename!r}"
msg = f'Tracing {disp.original_filename!r}'
if disp.original_filename != disp.source_filename:
msg += f" as {disp.source_filename!r}"
msg += f' as {disp.source_filename!r}'
if disp.file_tracer:
msg += f": will be traced by {disp.file_tracer!r}"
msg += f': will be traced by {disp.file_tracer!r}'
else:
msg = f"Not tracing {disp.original_filename!r}: {disp.reason}"
msg = f'Not tracing {disp.original_filename!r}: {disp.reason}'
return msg

View file

@ -1,48 +1,48 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Determine facts about the environment."""
from __future__ import annotations
import os
import platform
import sys
from typing import Any, Iterable
from typing import Any
from typing import Iterable
# debug_info() at the bottom wants to show all the globals, but not imports.
# Grab the global names here to know which names to not show. Nothing defined
# above this line will be in the output.
_UNINTERESTING_GLOBALS = list(globals())
# These names also shouldn't be shown.
_UNINTERESTING_GLOBALS += ["PYBEHAVIOR", "debug_info"]
_UNINTERESTING_GLOBALS += ['PYBEHAVIOR', 'debug_info']
# Operating systems.
WINDOWS = sys.platform == "win32"
LINUX = sys.platform.startswith("linux")
OSX = sys.platform == "darwin"
WINDOWS = sys.platform == 'win32'
LINUX = sys.platform.startswith('linux')
OSX = sys.platform == 'darwin'
# Python implementations.
CPYTHON = (platform.python_implementation() == "CPython")
PYPY = (platform.python_implementation() == "PyPy")
CPYTHON = (platform.python_implementation() == 'CPython')
PYPY = (platform.python_implementation() == 'PyPy')
# Python versions. We amend version_info with one more value, a zero if an
# official version, or 1 if built from source beyond an official version.
# Only use sys.version_info directly where tools like mypy need it to understand
# version-specfic code, otherwise use PYVERSION.
PYVERSION = sys.version_info + (int(platform.python_version()[-1] == "+"),)
PYVERSION = sys.version_info + (int(platform.python_version()[-1] == '+'),)
if PYPY:
PYPYVERSION = sys.pypy_version_info # type: ignore[attr-defined]
# Python behavior.
class PYBEHAVIOR:
"""Flags indicating this Python's behavior."""
# Does Python conform to PEP626, Precise line numbers for debugging and other tools.
# https://www.python.org/dev/peps/pep-0626
pep626 = (PYVERSION > (3, 10, 0, "alpha", 4))
pep626 = (PYVERSION > (3, 10, 0, 'alpha', 4))
# Is "if __debug__" optimized away?
optimize_if_debug = not pep626
@ -69,19 +69,19 @@ class PYBEHAVIOR:
# CPython 3.11 now jumps to the decorator line again while executing
# the decorator.
trace_decorator_line_again = (CPYTHON and PYVERSION > (3, 11, 0, "alpha", 3, 0))
trace_decorator_line_again = (CPYTHON and PYVERSION > (3, 11, 0, 'alpha', 3, 0))
# CPython 3.9a1 made sys.argv[0] and other reported files absolute paths.
report_absolute_files = (
(CPYTHON or (PYPY and PYPYVERSION >= (7, 3, 10)))
and PYVERSION >= (3, 9)
(CPYTHON or (PYPY and PYPYVERSION >= (7, 3, 10))) and
PYVERSION >= (3, 9)
)
# Lines after break/continue/return/raise are no longer compiled into the
# bytecode. They used to be marked as missing, now they aren't executable.
omit_after_jump = (
pep626
or (PYPY and PYVERSION >= (3, 9) and PYPYVERSION >= (7, 3, 12))
pep626 or
(PYPY and PYVERSION >= (3, 9) and PYPYVERSION >= (7, 3, 12))
)
# PyPy has always omitted statements after return.
@ -98,7 +98,7 @@ class PYBEHAVIOR:
keep_constant_test = pep626
# When leaving a with-block, do we visit the with-line again for the exit?
exit_through_with = (PYVERSION >= (3, 10, 0, "beta"))
exit_through_with = (PYVERSION >= (3, 10, 0, 'beta'))
# Match-case construct.
match_case = (PYVERSION >= (3, 10))
@ -108,14 +108,14 @@ class PYBEHAVIOR:
# Modules start with a line numbered zero. This means empty modules have
# only a 0-number line, which is ignored, giving a truly empty module.
empty_is_empty = (PYVERSION >= (3, 11, 0, "beta", 4))
empty_is_empty = (PYVERSION >= (3, 11, 0, 'beta', 4))
# Are comprehensions inlined (new) or compiled as called functions (old)?
# Changed in https://github.com/python/cpython/pull/101441
comprehensions_are_functions = (PYVERSION <= (3, 12, 0, "alpha", 7, 0))
comprehensions_are_functions = (PYVERSION <= (3, 12, 0, 'alpha', 7, 0))
# PEP669 Low Impact Monitoring: https://peps.python.org/pep-0669/
pep669 = bool(getattr(sys, "monitoring", None))
pep669 = bool(getattr(sys, 'monitoring', None))
# Where does frame.f_lasti point when yielding from a generator?
# It used to point at the YIELD, now it points at the RESUME.
@ -126,22 +126,22 @@ class PYBEHAVIOR:
# Coverage.py specifics, about testing scenarios. See tests/testenv.py also.
# Are we coverage-measuring ourselves?
METACOV = os.getenv("COVERAGE_COVERAGE") is not None
METACOV = os.getenv('COVERAGE_COVERAGE') is not None
# Are we running our test suite?
# Even when running tests, you can use COVERAGE_TESTING=0 to disable the
# test-specific behavior like AST checking.
TESTING = os.getenv("COVERAGE_TESTING") == "True"
TESTING = os.getenv('COVERAGE_TESTING') == 'True'
def debug_info() -> Iterable[tuple[str, Any]]:
"""Return a list of (name, value) pairs for printing debug information."""
info = [
(name, value) for name, value in globals().items()
if not name.startswith("_") and name not in _UNINTERESTING_GLOBALS
if not name.startswith('_') and name not in _UNINTERESTING_GLOBALS
]
info += [
(name, value) for name, value in PYBEHAVIOR.__dict__.items()
if not name.startswith("_")
if not name.startswith('_')
]
return sorted(info)

View file

@ -1,10 +1,9 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Exceptions coverage.py can raise."""
from __future__ import annotations
class _BaseCoverageException(Exception):
"""The base-base of all Coverage exceptions."""
pass
@ -24,6 +23,7 @@ class DataError(CoverageException):
"""An error in using a data file."""
pass
class NoDataError(CoverageException):
"""We didn't have data to work with."""
pass

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Execute files of Python code."""
from __future__ import annotations
import importlib.machinery
@ -12,14 +10,18 @@ import marshal
import os
import struct
import sys
from importlib.machinery import ModuleSpec
from types import CodeType, ModuleType
from types import CodeType
from types import ModuleType
from typing import Any
from coverage import env
from coverage.exceptions import CoverageException, _ExceptionDuringRun, NoCode, NoSource
from coverage.files import canonical_filename, python_reported_file
from coverage.exceptions import _ExceptionDuringRun
from coverage.exceptions import CoverageException
from coverage.exceptions import NoCode
from coverage.exceptions import NoSource
from coverage.files import canonical_filename
from coverage.files import python_reported_file
from coverage.misc import isolate_module
from coverage.python import get_python_source
@ -28,11 +30,13 @@ os = isolate_module(os)
PYC_MAGIC_NUMBER = importlib.util.MAGIC_NUMBER
class DummyLoader:
"""A shim for the pep302 __loader__, emulating pkgutil.ImpLoader.
Currently only implements the .fullname attribute
"""
def __init__(self, fullname: str, *_args: Any) -> None:
self.fullname = fullname
@ -50,20 +54,20 @@ def find_module(
except ImportError as err:
raise NoSource(str(err)) from err
if not spec:
raise NoSource(f"No module named {modulename!r}")
raise NoSource(f'No module named {modulename!r}')
pathname = spec.origin
packagename = spec.name
if spec.submodule_search_locations:
mod_main = modulename + ".__main__"
mod_main = modulename + '.__main__'
spec = importlib.util.find_spec(mod_main)
if not spec:
raise NoSource(
f"No module named {mod_main}; " +
f"{modulename!r} is a package and cannot be directly executed",
f'No module named {mod_main}; ' +
f'{modulename!r} is a package and cannot be directly executed',
)
pathname = spec.origin
packagename = spec.name
packagename = packagename.rpartition(".")[0]
packagename = packagename.rpartition('.')[0]
return pathname, packagename, spec
@ -73,6 +77,7 @@ class PyRunner:
This is meant to emulate real Python execution as closely as possible.
"""
def __init__(self, args: list[str], as_module: bool = False) -> None:
self.args = args
self.as_module = as_module
@ -142,8 +147,8 @@ class PyRunner:
elif os.path.isdir(self.arg0):
# Running a directory means running the __main__.py file in that
# directory.
for ext in [".py", ".pyc", ".pyo"]:
try_filename = os.path.join(self.arg0, "__main__" + ext)
for ext in ['.py', '.pyc', '.pyo']:
try_filename = os.path.join(self.arg0, '__main__' + ext)
# 3.8.10 changed how files are reported when running a
# directory. But I'm not sure how far this change is going to
# spread, so I'll just hard-code it here for now.
@ -157,12 +162,12 @@ class PyRunner:
# Make a spec. I don't know if this is the right way to do it.
try_filename = python_reported_file(try_filename)
self.spec = importlib.machinery.ModuleSpec("__main__", None, origin=try_filename)
self.spec = importlib.machinery.ModuleSpec('__main__', None, origin=try_filename)
self.spec.has_location = True
self.package = ""
self.loader = DummyLoader("__main__")
self.package = ''
self.loader = DummyLoader('__main__')
else:
self.loader = DummyLoader("__main__")
self.loader = DummyLoader('__main__')
self.arg0 = python_reported_file(self.arg0)
@ -172,9 +177,9 @@ class PyRunner:
self._prepare2()
# Create a module to serve as __main__
main_mod = ModuleType("__main__")
main_mod = ModuleType('__main__')
from_pyc = self.arg0.endswith((".pyc", ".pyo"))
from_pyc = self.arg0.endswith(('.pyc', '.pyo'))
main_mod.__file__ = self.arg0
if from_pyc:
main_mod.__file__ = main_mod.__file__[:-1]
@ -184,9 +189,9 @@ class PyRunner:
if self.spec is not None:
main_mod.__spec__ = self.spec
main_mod.__builtins__ = sys.modules["builtins"] # type: ignore[attr-defined]
main_mod.__builtins__ = sys.modules['builtins'] # type: ignore[attr-defined]
sys.modules["__main__"] = main_mod
sys.modules['__main__'] = main_mod
# Set sys.argv properly.
sys.argv = self.args
@ -228,7 +233,7 @@ class PyRunner:
# is non-None when the exception is reported at the upper layer,
# and a nested exception is shown to the user. This getattr fixes
# it somehow? https://bitbucket.org/pypy/pypy/issue/1903
getattr(err, "__context__", None)
getattr(err, '__context__', None)
# Call the excepthook.
try:
@ -240,7 +245,7 @@ class PyRunner:
except Exception as exc:
# Getting the output right in the case of excepthook
# shenanigans is kind of involved.
sys.stderr.write("Error in sys.excepthook:\n")
sys.stderr.write('Error in sys.excepthook:\n')
typ2, err2, tb2 = sys.exc_info()
assert typ2 is not None
assert err2 is not None
@ -249,7 +254,7 @@ class PyRunner:
assert err2.__traceback__ is not None
err2.__traceback__ = err2.__traceback__.tb_next
sys.__excepthook__(typ2, err2, tb2.tb_next)
sys.stderr.write("\nOriginal exception was:\n")
sys.stderr.write('\nOriginal exception was:\n')
raise _ExceptionDuringRun(typ, err, tb.tb_next) from exc
else:
sys.exit(1)
@ -294,13 +299,13 @@ def make_code_from_py(filename: str) -> CodeType:
except (OSError, NoSource) as exc:
raise NoSource(f"No file to run: '{filename}'") from exc
return compile(source, filename, "exec", dont_inherit=True)
return compile(source, filename, 'exec', dont_inherit=True)
def make_code_from_pyc(filename: str) -> CodeType:
"""Get a code object from a .pyc file."""
try:
fpyc = open(filename, "rb")
fpyc = open(filename, 'rb')
except OSError as exc:
raise NoCode(f"No file to run: '{filename}'") from exc
@ -309,9 +314,9 @@ def make_code_from_pyc(filename: str) -> CodeType:
# match or we won't run the file.
magic = fpyc.read(4)
if magic != PYC_MAGIC_NUMBER:
raise NoCode(f"Bad magic number in .pyc file: {magic!r} != {PYC_MAGIC_NUMBER!r}")
raise NoCode(f'Bad magic number in .pyc file: {magic!r} != {PYC_MAGIC_NUMBER!r}')
flags = struct.unpack("<L", fpyc.read(4))[0]
flags = struct.unpack('<L', fpyc.read(4))[0]
hash_based = flags & 0x01
if hash_based:
fpyc.read(8) # Skip the hash.

View file

@ -1,31 +1,31 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""File wrangling."""
from __future__ import annotations
import hashlib
import ntpath
import os
import os.path
import posixpath
import re
import sys
from typing import Callable, Iterable
from typing import Callable
from typing import Iterable
from coverage import env
from coverage.exceptions import ConfigError
from coverage.misc import human_sorted, isolate_module, join_regex
from coverage.misc import human_sorted
from coverage.misc import isolate_module
from coverage.misc import join_regex
os = isolate_module(os)
RELATIVE_DIR: str = ""
RELATIVE_DIR: str = ''
CANONICAL_FILENAME_CACHE: dict[str, str] = {}
def set_relative_directory() -> None:
"""Set the directory that `relative_filename` will be relative to."""
global RELATIVE_DIR, CANONICAL_FILENAME_CACHE
@ -73,7 +73,7 @@ def canonical_filename(filename: str) -> str:
if not os.path.isabs(filename):
for path in [os.curdir] + sys.path:
if path is None:
continue # type: ignore[unreachable]
continue # type: ignore[unreachable]
f = os.path.join(path, filename)
try:
exists = os.path.exists(f)
@ -89,6 +89,7 @@ def canonical_filename(filename: str) -> str:
MAX_FLAT = 100
def flat_rootname(filename: str) -> str:
"""A base for a flat file name to correspond to this file.
@ -101,11 +102,11 @@ def flat_rootname(filename: str) -> str:
"""
dirname, basename = ntpath.split(filename)
if dirname:
fp = hashlib.new("sha3_256", dirname.encode("UTF-8")).hexdigest()[:16]
prefix = f"d_{fp}_"
fp = hashlib.new('sha3_256', dirname.encode('UTF-8')).hexdigest()[:16]
prefix = f'd_{fp}_'
else:
prefix = ""
return prefix + basename.replace(".", "_")
prefix = ''
return prefix + basename.replace('.', '_')
if env.WINDOWS:
@ -163,7 +164,7 @@ def zip_location(filename: str) -> tuple[str, str] | None:
name is in the zipfile.
"""
for ext in [".zip", ".whl", ".egg", ".pex"]:
for ext in ['.zip', '.whl', '.egg', '.pex']:
zipbase, extension, inner = filename.partition(ext + sep(filename))
if extension:
zipfile = zipbase + ext
@ -210,7 +211,7 @@ def prep_patterns(patterns: Iterable[str]) -> list[str]:
prepped = []
for p in patterns or []:
prepped.append(p)
if not p.startswith(("*", "?")):
if not p.startswith(('*', '?')):
prepped.append(abs_file(p))
return prepped
@ -223,14 +224,15 @@ class TreeMatcher:
somewhere in a subtree rooted at one of the directories.
"""
def __init__(self, paths: Iterable[str], name: str = "unknown") -> None:
def __init__(self, paths: Iterable[str], name: str = 'unknown') -> None:
self.original_paths: list[str] = human_sorted(paths)
#self.paths = list(map(os.path.normcase, paths))
self.paths = [os.path.normcase(p) for p in paths]
self.name = name
def __repr__(self) -> str:
return f"<TreeMatcher {self.name} {self.original_paths!r}>"
return f'<TreeMatcher {self.name} {self.original_paths!r}>'
def info(self) -> list[str]:
"""A list of strings for displaying when dumping state."""
@ -252,12 +254,13 @@ class TreeMatcher:
class ModuleMatcher:
"""A matcher for modules in a tree."""
def __init__(self, module_names: Iterable[str], name:str = "unknown") -> None:
def __init__(self, module_names: Iterable[str], name: str = 'unknown') -> None:
self.modules = list(module_names)
self.name = name
def __repr__(self) -> str:
return f"<ModuleMatcher {self.name} {self.modules!r}>"
return f'<ModuleMatcher {self.name} {self.modules!r}>'
def info(self) -> list[str]:
"""A list of strings for displaying when dumping state."""
@ -272,7 +275,7 @@ class ModuleMatcher:
if module_name.startswith(m):
if module_name == m:
return True
if module_name[len(m)] == ".":
if module_name[len(m)] == '.':
# This is a module in the package
return True
@ -281,13 +284,14 @@ class ModuleMatcher:
class GlobMatcher:
"""A matcher for files by file name pattern."""
def __init__(self, pats: Iterable[str], name: str = "unknown") -> None:
def __init__(self, pats: Iterable[str], name: str = 'unknown') -> None:
self.pats = list(pats)
self.re = globs_to_regex(self.pats, case_insensitive=env.WINDOWS)
self.name = name
def __repr__(self) -> str:
return f"<GlobMatcher {self.name} {self.pats!r}>"
return f'<GlobMatcher {self.name} {self.pats!r}>'
def info(self) -> list[str]:
"""A list of strings for displaying when dumping state."""
@ -300,7 +304,7 @@ class GlobMatcher:
def sep(s: str) -> str:
"""Find the path separator used in this string, or os.sep if none."""
if sep_match := re.search(r"[\\/]", s):
if sep_match := re.search(r'[\\/]', s):
the_sep = sep_match[0]
else:
the_sep = os.sep
@ -309,29 +313,32 @@ def sep(s: str) -> str:
# Tokenizer for _glob_to_regex.
# None as a sub means disallowed.
G2RX_TOKENS = [(re.compile(rx), sub) for rx, sub in [
(r"\*\*\*+", None), # Can't have ***
(r"[^/]+\*\*+", None), # Can't have x**
(r"\*\*+[^/]+", None), # Can't have **x
(r"\*\*/\*\*", None), # Can't have **/**
(r"^\*+/", r"(.*[/\\\\])?"), # ^*/ matches any prefix-slash, or nothing.
(r"/\*+$", r"[/\\\\].*"), # /*$ matches any slash-suffix.
(r"\*\*/", r"(.*[/\\\\])?"), # **/ matches any subdirs, including none
(r"/", r"[/\\\\]"), # / matches either slash or backslash
(r"\*", r"[^/\\\\]*"), # * matches any number of non slash-likes
(r"\?", r"[^/\\\\]"), # ? matches one non slash-like
(r"\[.*?\]", r"\g<0>"), # [a-f] matches [a-f]
(r"[a-zA-Z0-9_-]+", r"\g<0>"), # word chars match themselves
(r"[\[\]]", None), # Can't have single square brackets
(r".", r"\\\g<0>"), # Anything else is escaped to be safe
]]
G2RX_TOKENS = [
(re.compile(rx), sub) for rx, sub in [
(r'\*\*\*+', None), # Can't have ***
(r'[^/]+\*\*+', None), # Can't have x**
(r'\*\*+[^/]+', None), # Can't have **x
(r'\*\*/\*\*', None), # Can't have **/**
(r'^\*+/', r'(.*[/\\\\])?'), # ^*/ matches any prefix-slash, or nothing.
(r'/\*+$', r'[/\\\\].*'), # /*$ matches any slash-suffix.
(r'\*\*/', r'(.*[/\\\\])?'), # **/ matches any subdirs, including none
(r'/', r'[/\\\\]'), # / matches either slash or backslash
(r'\*', r'[^/\\\\]*'), # * matches any number of non slash-likes
(r'\?', r'[^/\\\\]'), # ? matches one non slash-like
(r'\[.*?\]', r'\g<0>'), # [a-f] matches [a-f]
(r'[a-zA-Z0-9_-]+', r'\g<0>'), # word chars match themselves
(r'[\[\]]', None), # Can't have single square brackets
(r'.', r'\\\g<0>'), # Anything else is escaped to be safe
]
]
def _glob_to_regex(pattern: str) -> str:
"""Convert a file-path glob pattern into a regex."""
# Turn all backslashes into slashes to simplify the tokenizer.
pattern = pattern.replace("\\", "/")
if "/" not in pattern:
pattern = "**/" + pattern
pattern = pattern.replace('\\', '/')
if '/' not in pattern:
pattern = '**/' + pattern
path_rx = []
pos = 0
while pos < len(pattern):
@ -342,7 +349,7 @@ def _glob_to_regex(pattern: str) -> str:
path_rx.append(m.expand(sub))
pos = m.end()
break
return "".join(path_rx)
return ''.join(path_rx)
def globs_to_regex(
@ -371,7 +378,7 @@ def globs_to_regex(
flags |= re.IGNORECASE
rx = join_regex(map(_glob_to_regex, patterns))
if not partial:
rx = fr"(?:{rx})\Z"
rx = fr'(?:{rx})\Z'
compiled = re.compile(rx, flags=flags)
return compiled
@ -387,6 +394,7 @@ class PathAliases:
map a path through those aliases to produce a unified path.
"""
def __init__(
self,
debugfn: Callable[[str], None] | None = None,
@ -400,9 +408,9 @@ class PathAliases:
def pprint(self) -> None:
"""Dump the important parts of the PathAliases, for debugging."""
self.debugfn(f"Aliases (relative={self.relative}):")
self.debugfn(f'Aliases (relative={self.relative}):')
for original_pattern, regex, result in self.aliases:
self.debugfn(f" Rule: {original_pattern!r} -> {result!r} using regex {regex.pattern!r}")
self.debugfn(f' Rule: {original_pattern!r} -> {result!r} using regex {regex.pattern!r}')
def add(self, pattern: str, result: str) -> None:
"""Add the `pattern`/`result` pair to the list of aliases.
@ -421,16 +429,16 @@ class PathAliases:
pattern_sep = sep(pattern)
if len(pattern) > 1:
pattern = pattern.rstrip(r"\/")
pattern = pattern.rstrip(r'\/')
# The pattern can't end with a wildcard component.
if pattern.endswith("*"):
raise ConfigError("Pattern must not end with wildcards.")
if pattern.endswith('*'):
raise ConfigError('Pattern must not end with wildcards.')
# The pattern is meant to match a file path. Let's make it absolute
# unless it already is, or is meant to match any prefix.
if not self.relative:
if not pattern.startswith("*") and not isabs_anywhere(pattern + pattern_sep):
if not pattern.startswith('*') and not isabs_anywhere(pattern + pattern_sep):
pattern = abs_file(pattern)
if not pattern.endswith(pattern_sep):
pattern += pattern_sep
@ -440,10 +448,10 @@ class PathAliases:
# Normalize the result: it must end with a path separator.
result_sep = sep(result)
result = result.rstrip(r"\/") + result_sep
result = result.rstrip(r'\/') + result_sep
self.aliases.append((original_pattern, regex, result))
def map(self, path: str, exists:Callable[[str], bool] = source_exists) -> str:
def map(self, path: str, exists: Callable[[str], bool] = source_exists) -> str:
"""Map `path` through the aliases.
`path` is checked against all of the patterns. The first pattern to
@ -472,18 +480,18 @@ class PathAliases:
new = new.replace(sep(path), sep(result))
if not self.relative:
new = canonical_filename(new)
dot_start = result.startswith(("./", ".\\")) and len(result) > 2
if new.startswith(("./", ".\\")) and not dot_start:
dot_start = result.startswith(('./', '.\\')) and len(result) > 2
if new.startswith(('./', '.\\')) and not dot_start:
new = new[2:]
if not exists(new):
self.debugfn(
f"Rule {original_pattern!r} changed {path!r} to {new!r} " +
f'Rule {original_pattern!r} changed {path!r} to {new!r} ' +
"which doesn't exist, continuing",
)
continue
self.debugfn(
f"Matched path {path!r} to rule {original_pattern!r} -> {result!r}, " +
f"producing {new!r}",
f'Matched path {path!r} to rule {original_pattern!r} -> {result!r}, ' +
f'producing {new!r}',
)
return new
@ -494,21 +502,21 @@ class PathAliases:
if self.relative and not isabs_anywhere(path):
# Auto-generate a pattern to implicitly match relative files
parts = re.split(r"[/\\]", path)
parts = re.split(r'[/\\]', path)
if len(parts) > 1:
dir1 = parts[0]
pattern = f"*/{dir1}"
regex_pat = fr"^(.*[\\/])?{re.escape(dir1)}[\\/]"
result = f"{dir1}{os.sep}"
pattern = f'*/{dir1}'
regex_pat = fr'^(.*[\\/])?{re.escape(dir1)}[\\/]'
result = f'{dir1}{os.sep}'
# Only add a new pattern if we don't already have this pattern.
if not any(p == pattern for p, _, _ in self.aliases):
self.debugfn(
f"Generating rule: {pattern!r} -> {result!r} using regex {regex_pat!r}",
f'Generating rule: {pattern!r} -> {result!r} using regex {regex_pat!r}',
)
self.aliases.append((pattern, re.compile(regex_pat), result))
return self.map(path, exists=exists)
self.debugfn(f"No rules match, path {path!r} is unchanged")
self.debugfn(f'No rules match, path {path!r} is unchanged')
return path
@ -530,7 +538,7 @@ def find_python_files(dirname: str, include_namespace_packages: bool) -> Iterabl
"""
for i, (dirpath, dirnames, filenames) in enumerate(os.walk(dirname)):
if not include_namespace_packages:
if i > 0 and "__init__.py" not in filenames:
if i > 0 and '__init__.py' not in filenames:
# If a directory doesn't have __init__.py, then it isn't
# importable and neither are its files
del dirnames[:]
@ -539,7 +547,7 @@ def find_python_files(dirname: str, include_namespace_packages: bool) -> Iterabl
# We're only interested in files that look like reasonable Python
# files: Must end with .py or .pyw, and must not have certain funny
# characters that probably mean they are editor junk.
if re.match(r"^[^.#~!$@%^&*()+=,]+\.pyw?$", filename):
if re.match(r'^[^.#~!$@%^&*()+=,]+\.pyw?$', filename):
yield os.path.join(dirpath, filename)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""HTML reporting for coverage.py."""
from __future__ import annotations
import collections
@ -13,20 +11,31 @@ import os
import re
import shutil
import string
from dataclasses import dataclass
from typing import Any, Iterable, TYPE_CHECKING, cast
from typing import Any
from typing import cast
from typing import Iterable
from typing import TYPE_CHECKING
import coverage
from coverage.data import CoverageData, add_data_to_hash
from coverage.data import add_data_to_hash
from coverage.data import CoverageData
from coverage.exceptions import NoDataError
from coverage.files import flat_rootname
from coverage.misc import ensure_dir, file_be_gone, Hasher, isolate_module, format_local_datetime
from coverage.misc import human_sorted, plural, stdout_link
from coverage.misc import ensure_dir
from coverage.misc import file_be_gone
from coverage.misc import format_local_datetime
from coverage.misc import Hasher
from coverage.misc import human_sorted
from coverage.misc import isolate_module
from coverage.misc import plural
from coverage.misc import stdout_link
from coverage.report_core import get_analysis_to_report
from coverage.results import Analysis, Numbers
from coverage.results import Analysis
from coverage.results import Numbers
from coverage.templite import Templite
from coverage.types import TLineNo, TMorf
from coverage.types import TLineNo
from coverage.types import TMorf
from coverage.version import __url__
@ -56,7 +65,7 @@ os = isolate_module(os)
def data_filename(fname: str) -> str:
"""Return the path to an "htmlfiles" data file of ours.
"""
static_dir = os.path.join(os.path.dirname(__file__), "htmlfiles")
static_dir = os.path.join(os.path.dirname(__file__), 'htmlfiles')
static_filename = os.path.join(static_dir, fname)
return static_filename
@ -69,9 +78,9 @@ def read_data(fname: str) -> str:
def write_html(fname: str, html: str) -> None:
"""Write `html` to `fname`, properly encoded."""
html = re.sub(r"(\A\s+)|(\s+$)", "", html, flags=re.MULTILINE) + "\n"
with open(fname, "wb") as fout:
fout.write(html.encode("ascii", "xmlcharrefreplace"))
html = re.sub(r'(\A\s+)|(\s+$)', '', html, flags=re.MULTILINE) + '\n'
with open(fname, 'wb') as fout:
fout.write(html.encode('ascii', 'xmlcharrefreplace'))
@dataclass
@ -86,11 +95,11 @@ class LineData:
context_list: list[str]
short_annotations: list[str]
long_annotations: list[str]
html: str = ""
html: str = ''
context_str: str | None = None
annotate: str | None = None
annotate_long: str | None = None
css_class: str = ""
css_class: str = ''
@dataclass
@ -104,7 +113,7 @@ class FileData:
class HtmlDataGeneration:
"""Generate structured data to be turned into HTML reports."""
EMPTY = "(empty)"
EMPTY = '(empty)'
def __init__(self, cov: Coverage) -> None:
self.coverage = cov
@ -112,8 +121,8 @@ class HtmlDataGeneration:
data = self.coverage.get_data()
self.has_arcs = data.has_arcs()
if self.config.show_contexts:
if data.measured_contexts() == {""}:
self.coverage._warn("No contexts were measured")
if data.measured_contexts() == {''}:
self.coverage._warn('No contexts were measured')
data.set_query_contexts(self.config.report_contexts)
def data_for_file(self, fr: FileReporter, analysis: Analysis) -> FileData:
@ -129,47 +138,49 @@ class HtmlDataGeneration:
for lineno, tokens in enumerate(fr.source_token_lines(), start=1):
# Figure out how to mark this line.
category = ""
category = ''
short_annotations = []
long_annotations = []
if lineno in analysis.excluded:
category = "exc"
category = 'exc'
elif lineno in analysis.missing:
category = "mis"
category = 'mis'
elif self.has_arcs and lineno in missing_branch_arcs:
category = "par"
category = 'par'
for b in missing_branch_arcs[lineno]:
if b < 0:
short_annotations.append("exit")
short_annotations.append('exit')
else:
short_annotations.append(str(b))
long_annotations.append(fr.missing_arc_description(lineno, b, arcs_executed))
elif lineno in analysis.statements:
category = "run"
category = 'run'
contexts = []
contexts_label = ""
contexts_label = ''
context_list = []
if category and self.config.show_contexts:
contexts = human_sorted(c or self.EMPTY for c in contexts_by_lineno.get(lineno, ()))
if contexts == [self.EMPTY]:
contexts_label = self.EMPTY
else:
contexts_label = f"{len(contexts)} ctx"
contexts_label = f'{len(contexts)} ctx'
context_list = contexts
lines.append(LineData(
tokens=tokens,
number=lineno,
category=category,
statement=(lineno in analysis.statements),
contexts=contexts,
contexts_label=contexts_label,
context_list=context_list,
short_annotations=short_annotations,
long_annotations=long_annotations,
))
lines.append(
LineData(
tokens=tokens,
number=lineno,
category=category,
statement=(lineno in analysis.statements),
contexts=contexts,
contexts_label=contexts_label,
context_list=context_list,
short_annotations=short_annotations,
long_annotations=long_annotations,
),
)
file_data = FileData(
relative_filename=fr.relative_filename(),
@ -182,15 +193,17 @@ class HtmlDataGeneration:
class FileToReport:
"""A file we're considering reporting."""
def __init__(self, fr: FileReporter, analysis: Analysis) -> None:
self.fr = fr
self.analysis = analysis
self.rootname = flat_rootname(fr.relative_filename())
self.html_filename = self.rootname + ".html"
self.html_filename = self.rootname + '.html'
HTML_SAFE = string.ascii_letters + string.digits + "!#$%'()*+,-./:;=?@[]^_`{|}~"
@functools.lru_cache(maxsize=None)
def encode_int(n: int) -> str:
"""Create a short HTML-safe string from an integer, using HTML_SAFE."""
@ -201,7 +214,7 @@ def encode_int(n: int) -> str:
while n:
n, t = divmod(n, len(HTML_SAFE))
r.append(HTML_SAFE[t])
return "".join(r)
return ''.join(r)
class HtmlReporter:
@ -210,11 +223,11 @@ class HtmlReporter:
# These files will be copied from the htmlfiles directory to the output
# directory.
STATIC_FILES = [
"style.css",
"coverage_html.js",
"keybd_closed.png",
"keybd_open.png",
"favicon_32.png",
'style.css',
'coverage_html.js',
'keybd_closed.png',
'keybd_open.png',
'favicon_32.png',
]
def __init__(self, cov: Coverage) -> None:
@ -253,29 +266,29 @@ class HtmlReporter:
self.template_globals = {
# Functions available in the templates.
"escape": escape,
"pair": pair,
"len": len,
'escape': escape,
'pair': pair,
'len': len,
# Constants for this report.
"__url__": __url__,
"__version__": coverage.__version__,
"title": title,
"time_stamp": format_local_datetime(datetime.datetime.now()),
"extra_css": self.extra_css,
"has_arcs": self.has_arcs,
"show_contexts": self.config.show_contexts,
'__url__': __url__,
'__version__': coverage.__version__,
'title': title,
'time_stamp': format_local_datetime(datetime.datetime.now()),
'extra_css': self.extra_css,
'has_arcs': self.has_arcs,
'show_contexts': self.config.show_contexts,
# Constants for all reports.
# These css classes determine which lines are highlighted by default.
"category": {
"exc": "exc show_exc",
"mis": "mis show_mis",
"par": "par run show_par",
"run": "run",
'category': {
'exc': 'exc show_exc',
'mis': 'mis show_mis',
'par': 'par run show_par',
'run': 'run',
},
}
self.pyfile_html_source = read_data("pyfile.html")
self.pyfile_html_source = read_data('pyfile.html')
self.source_tmpl = Templite(self.pyfile_html_source, self.template_globals)
def report(self, morfs: Iterable[TMorf] | None) -> float:
@ -303,17 +316,17 @@ class HtmlReporter:
for i, ftr in enumerate(files_to_report):
if i == 0:
prev_html = "index.html"
prev_html = 'index.html'
else:
prev_html = files_to_report[i - 1].html_filename
if i == len(files_to_report) - 1:
next_html = "index.html"
next_html = 'index.html'
else:
next_html = files_to_report[i + 1].html_filename
self.write_html_file(ftr, prev_html, next_html)
if not self.all_files_nums:
raise NoDataError("No data to report.")
raise NoDataError('No data to report.')
self.totals = cast(Numbers, sum(self.all_files_nums))
@ -322,7 +335,7 @@ class HtmlReporter:
first_html = files_to_report[0].html_filename
final_html = files_to_report[-1].html_filename
else:
first_html = final_html = "index.html"
first_html = final_html = 'index.html'
self.index_file(first_html, final_html)
self.make_local_static_report_files()
@ -344,8 +357,8 @@ class HtmlReporter:
# .gitignore can't be copied from the source tree because it would
# prevent the static files from being checked in.
if self.directory_was_empty:
with open(os.path.join(self.directory, ".gitignore"), "w") as fgi:
fgi.write("# Created by coverage.py\n*\n")
with open(os.path.join(self.directory, '.gitignore'), 'w') as fgi:
fgi.write('# Created by coverage.py\n*\n')
# The user may have extra CSS they want copied.
if self.extra_css:
@ -401,29 +414,29 @@ class HtmlReporter:
# Build the HTML for the line.
html_parts = []
for tok_type, tok_text in ldata.tokens:
if tok_type == "ws":
if tok_type == 'ws':
html_parts.append(escape(tok_text))
else:
tok_html = escape(tok_text) or "&nbsp;"
tok_html = escape(tok_text) or '&nbsp;'
html_parts.append(f'<span class="{tok_type}">{tok_html}</span>')
ldata.html = "".join(html_parts)
ldata.html = ''.join(html_parts)
if ldata.context_list:
encoded_contexts = [
encode_int(context_codes[c_context]) for c_context in ldata.context_list
]
code_width = max(len(ec) for ec in encoded_contexts)
ldata.context_str = (
str(code_width)
+ "".join(ec.ljust(code_width) for ec in encoded_contexts)
str(code_width) +
''.join(ec.ljust(code_width) for ec in encoded_contexts)
)
else:
ldata.context_str = ""
ldata.context_str = ''
if ldata.short_annotations:
# 202F is NARROW NO-BREAK SPACE.
# 219B is RIGHTWARDS ARROW WITH STROKE.
ldata.annotate = ",&nbsp;&nbsp; ".join(
f"{ldata.number}&#x202F;&#x219B;&#x202F;{d}"
ldata.annotate = ',&nbsp;&nbsp; '.join(
f'{ldata.number}&#x202F;&#x219B;&#x202F;{d}'
for d in ldata.short_annotations
)
else:
@ -434,10 +447,10 @@ class HtmlReporter:
if len(longs) == 1:
ldata.annotate_long = longs[0]
else:
ldata.annotate_long = "{:d} missed branches: {}".format(
ldata.annotate_long = '{:d} missed branches: {}'.format(
len(longs),
", ".join(
f"{num:d}) {ann_long}"
', '.join(
f'{num:d}) {ann_long}'
for num, ann_long in enumerate(longs, start=1)
),
)
@ -447,24 +460,24 @@ class HtmlReporter:
css_classes = []
if ldata.category:
css_classes.append(
self.template_globals["category"][ldata.category], # type: ignore[index]
self.template_globals['category'][ldata.category], # type: ignore[index]
)
ldata.css_class = " ".join(css_classes) or "pln"
ldata.css_class = ' '.join(css_classes) or 'pln'
html_path = os.path.join(self.directory, ftr.html_filename)
html = self.source_tmpl.render({
**file_data.__dict__,
"contexts_json": contexts_json,
"prev_html": prev_html,
"next_html": next_html,
'contexts_json': contexts_json,
'prev_html': prev_html,
'next_html': next_html,
})
write_html(html_path, html)
# Save this file's information for the index file.
index_info: IndexInfoDict = {
"nums": ftr.analysis.numbers,
"html_filename": ftr.html_filename,
"relative_filename": ftr.fr.relative_filename(),
'nums': ftr.analysis.numbers,
'html_filename': ftr.html_filename,
'relative_filename': ftr.fr.relative_filename(),
}
self.file_summaries.append(index_info)
self.incr.set_index_info(ftr.rootname, index_info)
@ -472,30 +485,30 @@ class HtmlReporter:
def index_file(self, first_html: str, final_html: str) -> None:
"""Write the index.html file for this report."""
self.make_directory()
index_tmpl = Templite(read_data("index.html"), self.template_globals)
index_tmpl = Templite(read_data('index.html'), self.template_globals)
skipped_covered_msg = skipped_empty_msg = ""
skipped_covered_msg = skipped_empty_msg = ''
if self.skipped_covered_count:
n = self.skipped_covered_count
skipped_covered_msg = f"{n} file{plural(n)} skipped due to complete coverage."
skipped_covered_msg = f'{n} file{plural(n)} skipped due to complete coverage.'
if self.skipped_empty_count:
n = self.skipped_empty_count
skipped_empty_msg = f"{n} empty file{plural(n)} skipped."
skipped_empty_msg = f'{n} empty file{plural(n)} skipped.'
html = index_tmpl.render({
"files": self.file_summaries,
"totals": self.totals,
"skipped_covered_msg": skipped_covered_msg,
"skipped_empty_msg": skipped_empty_msg,
"first_html": first_html,
"final_html": final_html,
'files': self.file_summaries,
'totals': self.totals,
'skipped_covered_msg': skipped_covered_msg,
'skipped_empty_msg': skipped_empty_msg,
'first_html': first_html,
'final_html': final_html,
})
index_file = os.path.join(self.directory, "index.html")
index_file = os.path.join(self.directory, 'index.html')
write_html(index_file, html)
print_href = stdout_link(index_file, f"file://{os.path.abspath(index_file)}")
self.coverage._message(f"Wrote HTML report to {print_href}")
print_href = stdout_link(index_file, f'file://{os.path.abspath(index_file)}')
self.coverage._message(f'Wrote HTML report to {print_href}')
# Write the latest hashes for next time.
self.incr.write()
@ -504,12 +517,12 @@ class HtmlReporter:
class IncrementalChecker:
"""Logic and data to support incremental reporting."""
STATUS_FILE = "status.json"
STATUS_FILE = 'status.json'
STATUS_FORMAT = 2
NOTE = (
"This file is an internal implementation detail to speed up HTML report"
+ " generation. Its format can change at any time. You might be looking"
+ " for the JSON report: https://coverage.rtfd.io/cmd.html#cmd-json"
'This file is an internal implementation detail to speed up HTML report' +
' generation. Its format can change at any time. You might be looking' +
' for the JSON report: https://coverage.rtfd.io/cmd.html#cmd-json'
)
# The data looks like:
@ -545,7 +558,7 @@ class IncrementalChecker:
def reset(self) -> None:
"""Initialize to empty. Causes all files to be reported."""
self.globals = ""
self.globals = ''
self.files: dict[str, FileInfoDict] = {}
def read(self) -> None:
@ -559,17 +572,17 @@ class IncrementalChecker:
usable = False
else:
usable = True
if status["format"] != self.STATUS_FORMAT:
if status['format'] != self.STATUS_FORMAT:
usable = False
elif status["version"] != coverage.__version__:
elif status['version'] != coverage.__version__:
usable = False
if usable:
self.files = {}
for filename, fileinfo in status["files"].items():
fileinfo["index"]["nums"] = Numbers(*fileinfo["index"]["nums"])
for filename, fileinfo in status['files'].items():
fileinfo['index']['nums'] = Numbers(*fileinfo['index']['nums'])
self.files[filename] = fileinfo
self.globals = status["globals"]
self.globals = status['globals']
else:
self.reset()
@ -578,19 +591,19 @@ class IncrementalChecker:
status_file = os.path.join(self.directory, self.STATUS_FILE)
files = {}
for filename, fileinfo in self.files.items():
index = fileinfo["index"]
index["nums"] = index["nums"].init_args() # type: ignore[typeddict-item]
index = fileinfo['index']
index['nums'] = index['nums'].init_args() # type: ignore[typeddict-item]
files[filename] = fileinfo
status = {
"note": self.NOTE,
"format": self.STATUS_FORMAT,
"version": coverage.__version__,
"globals": self.globals,
"files": files,
'note': self.NOTE,
'format': self.STATUS_FORMAT,
'version': coverage.__version__,
'globals': self.globals,
'files': files,
}
with open(status_file, "w") as fout:
json.dump(status, fout, separators=(",", ":"))
with open(status_file, 'w') as fout:
json.dump(status, fout, separators=(',', ':'))
def check_global_data(self, *data: Any) -> None:
"""Check the global data that can affect incremental reporting."""
@ -609,7 +622,7 @@ class IncrementalChecker:
`rootname` is the name being used for the file.
"""
m = Hasher()
m.update(fr.source().encode("utf-8"))
m.update(fr.source().encode('utf-8'))
add_data_to_hash(data, fr.filename, m)
this_hash = m.hexdigest()
@ -624,19 +637,19 @@ class IncrementalChecker:
def file_hash(self, fname: str) -> str:
"""Get the hash of `fname`'s contents."""
return self.files.get(fname, {}).get("hash", "") # type: ignore[call-overload]
return self.files.get(fname, {}).get('hash', '') # type: ignore[call-overload]
def set_file_hash(self, fname: str, val: str) -> None:
"""Set the hash of `fname`'s contents."""
self.files.setdefault(fname, {})["hash"] = val # type: ignore[typeddict-item]
self.files.setdefault(fname, {})['hash'] = val # type: ignore[typeddict-item]
def index_info(self, fname: str) -> IndexInfoDict:
"""Get the information for index.html for `fname`."""
return self.files.get(fname, {}).get("index", {}) # type: ignore
return self.files.get(fname, {}).get('index', {}) # type: ignore
def set_index_info(self, fname: str, info: IndexInfoDict) -> None:
"""Set the information for index.html for `fname`."""
self.files.setdefault(fname, {})["index"] = info # type: ignore[typeddict-item]
self.files.setdefault(fname, {})['index'] = info # type: ignore[typeddict-item]
# Helpers for templates and generating HTML
@ -648,9 +661,9 @@ def escape(t: str) -> str:
"""
# Convert HTML special chars into HTML entities.
return t.replace("&", "&amp;").replace("<", "&lt;")
return t.replace('&', '&amp;').replace('<', '&lt;')
def pair(ratio: tuple[int, int]) -> str:
"""Format a pair of numbers so JavaScript can read them in an attribute."""
return "{} {}".format(*ratio)
return '{} {}'.format(*ratio)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Determining whether files are being measured/reported or not."""
from __future__ import annotations
import importlib.util
@ -14,20 +12,31 @@ import re
import sys
import sysconfig
import traceback
from types import FrameType, ModuleType
from typing import (
cast, Any, Iterable, TYPE_CHECKING,
)
from types import FrameType
from types import ModuleType
from typing import Any
from typing import cast
from typing import Iterable
from typing import TYPE_CHECKING
from coverage import env
from coverage.disposition import FileDisposition, disposition_init
from coverage.exceptions import CoverageException, PluginError
from coverage.files import TreeMatcher, GlobMatcher, ModuleMatcher
from coverage.files import prep_patterns, find_python_files, canonical_filename
from coverage.disposition import disposition_init
from coverage.disposition import FileDisposition
from coverage.exceptions import CoverageException
from coverage.exceptions import PluginError
from coverage.files import canonical_filename
from coverage.files import find_python_files
from coverage.files import GlobMatcher
from coverage.files import ModuleMatcher
from coverage.files import prep_patterns
from coverage.files import TreeMatcher
from coverage.misc import sys_modules_saved
from coverage.python import source_for_file, source_for_morf
from coverage.types import TFileDisposition, TMorf, TWarnFn, TDebugCtl
from coverage.python import source_for_file
from coverage.python import source_for_morf
from coverage.types import TDebugCtl
from coverage.types import TFileDisposition
from coverage.types import TMorf
from coverage.types import TWarnFn
if TYPE_CHECKING:
from coverage.config import CoverageConfig
@ -65,7 +74,7 @@ def canonical_path(morf: TMorf, directory: bool = False) -> str:
"""
morf_path = canonical_filename(source_for_morf(morf))
if morf_path.endswith("__init__.py") or directory:
if morf_path.endswith('__init__.py') or directory:
morf_path = os.path.split(morf_path)[0]
return morf_path
@ -83,16 +92,16 @@ def name_for_module(filename: str, frame: FrameType | None) -> str:
"""
module_globals = frame.f_globals if frame is not None else {}
dunder_name: str = module_globals.get("__name__", None)
dunder_name: str = module_globals.get('__name__', None)
if isinstance(dunder_name, str) and dunder_name != "__main__":
if isinstance(dunder_name, str) and dunder_name != '__main__':
# This is the usual case: an imported module.
return dunder_name
spec = module_globals.get("__spec__", None)
spec = module_globals.get('__spec__', None)
if spec:
fullname = spec.name
if isinstance(fullname, str) and fullname != "__main__":
if isinstance(fullname, str) and fullname != '__main__':
# Module loaded via: runpy -m
return fullname
@ -106,12 +115,12 @@ def name_for_module(filename: str, frame: FrameType | None) -> str:
def module_is_namespace(mod: ModuleType) -> bool:
"""Is the module object `mod` a PEP420 namespace module?"""
return hasattr(mod, "__path__") and getattr(mod, "__file__", None) is None
return hasattr(mod, '__path__') and getattr(mod, '__file__', None) is None
def module_has_file(mod: ModuleType) -> bool:
"""Does the module object `mod` have an existing __file__ ?"""
mod__file__ = getattr(mod, "__file__", None)
mod__file__ = getattr(mod, '__file__', None)
if mod__file__ is None:
return False
return os.path.exists(mod__file__)
@ -146,7 +155,7 @@ def add_stdlib_paths(paths: set[str]) -> None:
# spread across a few locations. Look at all the candidate modules
# we've imported, and take all the different ones.
for m in modules_we_happen_to_have:
if hasattr(m, "__file__"):
if hasattr(m, '__file__'):
paths.add(canonical_path(m, directory=True))
@ -157,10 +166,10 @@ def add_third_party_paths(paths: set[str]) -> None:
for scheme in scheme_names:
# https://foss.heptapod.net/pypy/pypy/-/issues/3433
better_scheme = "pypy_posix" if scheme == "pypy" else scheme
if os.name in better_scheme.split("_"):
better_scheme = 'pypy_posix' if scheme == 'pypy' else scheme
if os.name in better_scheme.split('_'):
config_paths = sysconfig.get_paths(scheme)
for path_name in ["platlib", "purelib", "scripts"]:
for path_name in ['platlib', 'purelib', 'scripts']:
paths.add(config_paths[path_name])
@ -170,7 +179,7 @@ def add_coverage_paths(paths: set[str]) -> None:
paths.add(cover_path)
if env.TESTING:
# Don't include our own test code.
paths.add(os.path.join(cover_path, "tests"))
paths.add(os.path.join(cover_path, 'tests'))
class InOrOut:
@ -221,7 +230,7 @@ class InOrOut:
# The matchers for should_trace.
# Generally useful information
_debug("sys.path:" + "".join(f"\n {p}" for p in sys.path))
_debug('sys.path:' + ''.join(f'\n {p}' for p in sys.path))
# Create the matchers we need for should_trace
self.source_match = None
@ -232,28 +241,28 @@ class InOrOut:
if self.source or self.source_pkgs:
against = []
if self.source:
self.source_match = TreeMatcher(self.source, "source")
against.append(f"trees {self.source_match!r}")
self.source_match = TreeMatcher(self.source, 'source')
against.append(f'trees {self.source_match!r}')
if self.source_pkgs:
self.source_pkgs_match = ModuleMatcher(self.source_pkgs, "source_pkgs")
against.append(f"modules {self.source_pkgs_match!r}")
_debug("Source matching against " + " and ".join(against))
self.source_pkgs_match = ModuleMatcher(self.source_pkgs, 'source_pkgs')
against.append(f'modules {self.source_pkgs_match!r}')
_debug('Source matching against ' + ' and '.join(against))
else:
if self.pylib_paths:
self.pylib_match = TreeMatcher(self.pylib_paths, "pylib")
_debug(f"Python stdlib matching: {self.pylib_match!r}")
self.pylib_match = TreeMatcher(self.pylib_paths, 'pylib')
_debug(f'Python stdlib matching: {self.pylib_match!r}')
if self.include:
self.include_match = GlobMatcher(self.include, "include")
_debug(f"Include matching: {self.include_match!r}")
self.include_match = GlobMatcher(self.include, 'include')
_debug(f'Include matching: {self.include_match!r}')
if self.omit:
self.omit_match = GlobMatcher(self.omit, "omit")
_debug(f"Omit matching: {self.omit_match!r}")
self.omit_match = GlobMatcher(self.omit, 'omit')
_debug(f'Omit matching: {self.omit_match!r}')
self.cover_match = TreeMatcher(self.cover_paths, "coverage")
_debug(f"Coverage code matching: {self.cover_match!r}")
self.cover_match = TreeMatcher(self.cover_paths, 'coverage')
_debug(f'Coverage code matching: {self.cover_match!r}')
self.third_match = TreeMatcher(self.third_paths, "third")
_debug(f"Third-party lib matching: {self.third_match!r}")
self.third_match = TreeMatcher(self.third_paths, 'third')
_debug(f'Third-party lib matching: {self.third_match!r}')
# Check if the source we want to measure has been installed as a
# third-party package.
@ -263,30 +272,30 @@ class InOrOut:
for pkg in self.source_pkgs:
try:
modfile, path = file_and_path_for_module(pkg)
_debug(f"Imported source package {pkg!r} as {modfile!r}")
_debug(f'Imported source package {pkg!r} as {modfile!r}')
except CoverageException as exc:
_debug(f"Couldn't import source package {pkg!r}: {exc}")
continue
if modfile:
if self.third_match.match(modfile):
_debug(
f"Source in third-party: source_pkg {pkg!r} at {modfile!r}",
f'Source in third-party: source_pkg {pkg!r} at {modfile!r}',
)
self.source_in_third_paths.add(canonical_path(source_for_file(modfile)))
else:
for pathdir in path:
if self.third_match.match(pathdir):
_debug(
f"Source in third-party: {pkg!r} path directory at {pathdir!r}",
f'Source in third-party: {pkg!r} path directory at {pathdir!r}',
)
self.source_in_third_paths.add(pathdir)
for src in self.source:
if self.third_match.match(src):
_debug(f"Source in third-party: source directory {src!r}")
_debug(f'Source in third-party: source directory {src!r}')
self.source_in_third_paths.add(src)
self.source_in_third_match = TreeMatcher(self.source_in_third_paths, "source_in_third")
_debug(f"Source in third-party matching: {self.source_in_third_match}")
self.source_in_third_match = TreeMatcher(self.source_in_third_paths, 'source_in_third')
_debug(f'Source in third-party matching: {self.source_in_third_match}')
self.plugins: Plugins
self.disp_class: type[TFileDisposition] = FileDisposition
@ -309,8 +318,8 @@ class InOrOut:
disp.reason = reason
return disp
if original_filename.startswith("<"):
return nope(disp, "original file name is not real")
if original_filename.startswith('<'):
return nope(disp, 'original file name is not real')
if frame is not None:
# Compiled Python files have two file names: frame.f_code.co_filename is
@ -319,10 +328,10 @@ class InOrOut:
# .pyc files can be moved after compilation (for example, by being
# installed), we look for __file__ in the frame and prefer it to the
# co_filename value.
dunder_file = frame.f_globals and frame.f_globals.get("__file__")
dunder_file = frame.f_globals and frame.f_globals.get('__file__')
if dunder_file:
filename = source_for_file(dunder_file)
if original_filename and not original_filename.startswith("<"):
if original_filename and not original_filename.startswith('<'):
orig = os.path.basename(original_filename)
if orig != os.path.basename(filename):
# Files shouldn't be renamed when moved. This happens when
@ -334,15 +343,15 @@ class InOrOut:
# Empty string is pretty useless.
return nope(disp, "empty string isn't a file name")
if filename.startswith("memory:"):
if filename.startswith('memory:'):
return nope(disp, "memory isn't traceable")
if filename.startswith("<"):
if filename.startswith('<'):
# Lots of non-file execution is represented with artificial
# file names like "<string>", "<doctest readme.txt[0]>", or
# "<exec_function>". Don't ever trace these executions, since we
# can't do anything with the data later anyway.
return nope(disp, "file name is not real")
return nope(disp, 'file name is not real')
canonical = canonical_filename(filename)
disp.canonical_filename = canonical
@ -369,7 +378,7 @@ class InOrOut:
except Exception:
plugin_name = plugin._coverage_plugin_name
tb = traceback.format_exc()
self.warn(f"Disabling plug-in {plugin_name!r} due to an exception:\n{tb}")
self.warn(f'Disabling plug-in {plugin_name!r} due to an exception:\n{tb}')
plugin._coverage_enabled = False
continue
else:
@ -402,7 +411,7 @@ class InOrOut:
# any canned exclusions. If they didn't, then we have to exclude the
# stdlib and coverage.py directories.
if self.source_match or self.source_pkgs_match:
extra = ""
extra = ''
ok = False
if self.source_pkgs_match:
if self.source_pkgs_match.match(modulename):
@ -410,41 +419,41 @@ class InOrOut:
if modulename in self.source_pkgs_unmatched:
self.source_pkgs_unmatched.remove(modulename)
else:
extra = f"module {modulename!r} "
extra = f'module {modulename!r} '
if not ok and self.source_match:
if self.source_match.match(filename):
ok = True
if not ok:
return extra + "falls outside the --source spec"
return extra + 'falls outside the --source spec'
if self.third_match.match(filename) and not self.source_in_third_match.match(filename):
return "inside --source, but is third-party"
return 'inside --source, but is third-party'
elif self.include_match:
if not self.include_match.match(filename):
return "falls outside the --include trees"
return 'falls outside the --include trees'
else:
# We exclude the coverage.py code itself, since a little of it
# will be measured otherwise.
if self.cover_match.match(filename):
return "is part of coverage.py"
return 'is part of coverage.py'
# If we aren't supposed to trace installed code, then check if this
# is near the Python standard library and skip it if so.
if self.pylib_match and self.pylib_match.match(filename):
return "is in the stdlib"
return 'is in the stdlib'
# Exclude anything in the third-party installation areas.
if self.third_match.match(filename):
return "is a third-party module"
return 'is a third-party module'
# Check the file against the omit pattern.
if self.omit_match and self.omit_match.match(filename):
return "is inside an --omit pattern"
return 'is inside an --omit pattern'
# No point tracing a file we can't later write to SQLite.
try:
filename.encode("utf-8")
filename.encode('utf-8')
except UnicodeEncodeError:
return "non-encodable filename"
return 'non-encodable filename'
# No reason found to skip this file.
return None
@ -453,20 +462,20 @@ class InOrOut:
"""Warn if there are settings that conflict."""
if self.include:
if self.source or self.source_pkgs:
self.warn("--include is ignored because --source is set", slug="include-ignored")
self.warn('--include is ignored because --source is set', slug='include-ignored')
def warn_already_imported_files(self) -> None:
"""Warn if files have already been imported that we will be measuring."""
if self.include or self.source or self.source_pkgs:
warned = set()
for mod in list(sys.modules.values()):
filename = getattr(mod, "__file__", None)
filename = getattr(mod, '__file__', None)
if filename is None:
continue
if filename in warned:
continue
if len(getattr(mod, "__path__", ())) > 1:
if len(getattr(mod, '__path__', ())) > 1:
# A namespace package, which confuses this code, so ignore it.
continue
@ -477,10 +486,10 @@ class InOrOut:
# of tracing anyway.
continue
if disp.trace:
msg = f"Already imported a file that will be measured: {filename}"
self.warn(msg, slug="already-imported")
msg = f'Already imported a file that will be measured: {filename}'
self.warn(msg, slug='already-imported')
warned.add(filename)
elif self.debug and self.debug.should("trace"):
elif self.debug and self.debug.should('trace'):
self.debug.write(
"Didn't trace already imported file {!r}: {}".format(
disp.original_filename, disp.reason,
@ -500,7 +509,7 @@ class InOrOut:
"""
mod = sys.modules.get(pkg)
if mod is None:
self.warn(f"Module {pkg} was never imported.", slug="module-not-imported")
self.warn(f'Module {pkg} was never imported.', slug='module-not-imported')
return
if module_is_namespace(mod):
@ -509,14 +518,14 @@ class InOrOut:
return
if not module_has_file(mod):
self.warn(f"Module {pkg} has no Python source.", slug="module-not-python")
self.warn(f'Module {pkg} has no Python source.', slug='module-not-python')
return
# The module was in sys.modules, and seems like a module with code, but
# we never measured it. I guess that means it was imported before
# coverage even started.
msg = f"Module {pkg} was previously imported, but not measured"
self.warn(msg, slug="module-not-measured")
msg = f'Module {pkg} was previously imported, but not measured'
self.warn(msg, slug='module-not-measured')
def find_possibly_unexecuted_files(self) -> Iterable[tuple[str, str | None]]:
"""Find files in the areas of interest that might be untraced.
@ -524,8 +533,10 @@ class InOrOut:
Yields pairs: file path, and responsible plug-in name.
"""
for pkg in self.source_pkgs:
if (pkg not in sys.modules or
not module_has_file(sys.modules[pkg])):
if (
pkg not in sys.modules or
not module_has_file(sys.modules[pkg])
):
continue
pkg_file = source_for_file(cast(str, sys.modules[pkg].__file__))
yield from self._find_executable_files(canonical_path(pkg_file))
@ -569,16 +580,16 @@ class InOrOut:
Returns a list of (key, value) pairs.
"""
info = [
("coverage_paths", self.cover_paths),
("stdlib_paths", self.pylib_paths),
("third_party_paths", self.third_paths),
("source_in_third_party_paths", self.source_in_third_paths),
('coverage_paths', self.cover_paths),
('stdlib_paths', self.pylib_paths),
('third_party_paths', self.third_paths),
('source_in_third_party_paths', self.source_in_third_paths),
]
matcher_names = [
"source_match", "source_pkgs_match",
"include_match", "omit_match",
"cover_match", "pylib_match", "third_match", "source_in_third_match",
'source_match', 'source_pkgs_match',
'include_match', 'omit_match',
'cover_match', 'pylib_match', 'third_match', 'source_in_third_match',
]
for matcher_name in matcher_names:
@ -586,7 +597,7 @@ class InOrOut:
if matcher:
matcher_info = matcher.info()
else:
matcher_info = "-none-"
matcher_info = '-none-'
info.append((matcher_name, matcher_info))
return info

View file

@ -1,20 +1,22 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Json reporting for coverage.py"""
from __future__ import annotations
import datetime
import json
import sys
from typing import Any, IO, Iterable, TYPE_CHECKING
from typing import Any
from typing import IO
from typing import Iterable
from typing import TYPE_CHECKING
from coverage import __version__
from coverage.report_core import get_analysis_to_report
from coverage.results import Analysis, Numbers
from coverage.types import TMorf, TLineNo
from coverage.results import Analysis
from coverage.results import Numbers
from coverage.types import TLineNo
from coverage.types import TMorf
if TYPE_CHECKING:
from coverage import Coverage
@ -25,10 +27,11 @@ if TYPE_CHECKING:
# 2: add the meta.format field.
FORMAT_VERSION = 2
class JsonReporter:
"""A reporter for writing JSON coverage results."""
report_type = "JSON report"
report_type = 'JSON report'
def __init__(self, coverage: Coverage) -> None:
self.coverage = coverage
@ -47,12 +50,12 @@ class JsonReporter:
outfile = outfile or sys.stdout
coverage_data = self.coverage.get_data()
coverage_data.set_query_contexts(self.config.report_contexts)
self.report_data["meta"] = {
"format": FORMAT_VERSION,
"version": __version__,
"timestamp": datetime.datetime.now().isoformat(),
"branch_coverage": coverage_data.has_arcs(),
"show_contexts": self.config.json_show_contexts,
self.report_data['meta'] = {
'format': FORMAT_VERSION,
'version': __version__,
'timestamp': datetime.datetime.now().isoformat(),
'branch_coverage': coverage_data.has_arcs(),
'show_contexts': self.config.json_show_contexts,
}
measured_files = {}
@ -62,23 +65,23 @@ class JsonReporter:
analysis,
)
self.report_data["files"] = measured_files
self.report_data['files'] = measured_files
self.report_data["totals"] = {
"covered_lines": self.total.n_executed,
"num_statements": self.total.n_statements,
"percent_covered": self.total.pc_covered,
"percent_covered_display": self.total.pc_covered_str,
"missing_lines": self.total.n_missing,
"excluded_lines": self.total.n_excluded,
self.report_data['totals'] = {
'covered_lines': self.total.n_executed,
'num_statements': self.total.n_statements,
'percent_covered': self.total.pc_covered,
'percent_covered_display': self.total.pc_covered_str,
'missing_lines': self.total.n_missing,
'excluded_lines': self.total.n_excluded,
}
if coverage_data.has_arcs():
self.report_data["totals"].update({
"num_branches": self.total.n_branches,
"num_partial_branches": self.total.n_partial_branches,
"covered_branches": self.total.n_executed_branches,
"missing_branches": self.total.n_missing_branches,
self.report_data['totals'].update({
'num_branches': self.total.n_branches,
'num_partial_branches': self.total.n_partial_branches,
'covered_branches': self.total.n_executed_branches,
'missing_branches': self.total.n_missing_branches,
})
json.dump(
@ -94,32 +97,32 @@ class JsonReporter:
nums = analysis.numbers
self.total += nums
summary = {
"covered_lines": nums.n_executed,
"num_statements": nums.n_statements,
"percent_covered": nums.pc_covered,
"percent_covered_display": nums.pc_covered_str,
"missing_lines": nums.n_missing,
"excluded_lines": nums.n_excluded,
'covered_lines': nums.n_executed,
'num_statements': nums.n_statements,
'percent_covered': nums.pc_covered,
'percent_covered_display': nums.pc_covered_str,
'missing_lines': nums.n_missing,
'excluded_lines': nums.n_excluded,
}
reported_file = {
"executed_lines": sorted(analysis.executed),
"summary": summary,
"missing_lines": sorted(analysis.missing),
"excluded_lines": sorted(analysis.excluded),
'executed_lines': sorted(analysis.executed),
'summary': summary,
'missing_lines': sorted(analysis.missing),
'excluded_lines': sorted(analysis.excluded),
}
if self.config.json_show_contexts:
reported_file["contexts"] = analysis.data.contexts_by_lineno(analysis.filename)
reported_file['contexts'] = analysis.data.contexts_by_lineno(analysis.filename)
if coverage_data.has_arcs():
summary.update({
"num_branches": nums.n_branches,
"num_partial_branches": nums.n_partial_branches,
"covered_branches": nums.n_executed_branches,
"missing_branches": nums.n_missing_branches,
'num_branches': nums.n_branches,
'num_partial_branches': nums.n_partial_branches,
'covered_branches': nums.n_executed_branches,
'missing_branches': nums.n_missing_branches,
})
reported_file["executed_branches"] = list(
reported_file['executed_branches'] = list(
_convert_branch_arcs(analysis.executed_branch_arcs()),
)
reported_file["missing_branches"] = list(
reported_file['missing_branches'] = list(
_convert_branch_arcs(analysis.missing_branch_arcs()),
)
return reported_file

Some files were not shown because too many files have changed in this diff Show more