[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2024-04-13 00:00:18 +00:00
parent 72ad6dc953
commit f4cd1ba0d6
813 changed files with 66015 additions and 58839 deletions

View file

@ -39,4 +39,4 @@ repos:
rev: v1.9.0 rev: v1.9.0
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: [types-all] additional_dependencies: [types-all]

View file

@ -44,7 +44,7 @@ command:
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
For more information on Execution Policies: For more information on Execution Policies:
https://go.microsoft.com/fwlink/?LinkID=135170 https://go.microsoft.com/fwlink/?LinkID=135170
#> #>

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from coverage.cmdline import main from coverage.cmdline import main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from coverage.cmdline import main from coverage.cmdline import main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from coverage.cmdline import main from coverage.cmdline import main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from pip._internal.cli.main import main from pip._internal.cli.main import main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from pip._internal.cli.main import main from pip._internal.cli.main import main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from pip._internal.cli.main import main from pip._internal.cli.main import main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from pytest import console_main from pytest import console_main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,7 +1,9 @@
#!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3 #!/Users/admin/Git_repos/pre-commit-hooks/.venv/bin/python3
# -*- coding: utf-8 -*- from __future__ import annotations
import re import re
import sys import sys
from pytest import console_main from pytest import console_main
if __name__ == '__main__': if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])

View file

@ -1,16 +1,20 @@
import sys from __future__ import annotations
import importlib
import os import os
import re import re
import importlib import sys
import warnings import warnings
is_pypy = '__pypy__' in sys.builtin_module_names is_pypy = '__pypy__' in sys.builtin_module_names
warnings.filterwarnings('ignore', warnings.filterwarnings(
r'.+ distutils\b.+ deprecated', 'ignore',
DeprecationWarning) r'.+ distutils\b.+ deprecated',
DeprecationWarning,
)
def warn_distutils_present(): def warn_distutils_present():
@ -21,18 +25,19 @@ def warn_distutils_present():
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250 # https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
return return
warnings.warn( warnings.warn(
"Distutils was imported before Setuptools, but importing Setuptools " 'Distutils was imported before Setuptools, but importing Setuptools '
"also replaces the `distutils` module in `sys.modules`. This may lead " 'also replaces the `distutils` module in `sys.modules`. This may lead '
"to undesirable behaviors or errors. To avoid these issues, avoid " 'to undesirable behaviors or errors. To avoid these issues, avoid '
"using distutils directly, ensure that setuptools is installed in the " 'using distutils directly, ensure that setuptools is installed in the '
"traditional way (e.g. not an editable install), and/or make sure " 'traditional way (e.g. not an editable install), and/or make sure '
"that setuptools is always imported before distutils.") 'that setuptools is always imported before distutils.',
)
def clear_distutils(): def clear_distutils():
if 'distutils' not in sys.modules: if 'distutils' not in sys.modules:
return return
warnings.warn("Setuptools is replacing distutils.") warnings.warn('Setuptools is replacing distutils.')
mods = [name for name in sys.modules if re.match(r'distutils\b', name)] mods = [name for name in sys.modules if re.match(r'distutils\b', name)]
for name in mods: for name in mods:
del sys.modules[name] del sys.modules[name]
@ -74,7 +79,7 @@ class DistutilsMetaFinder:
if path is not None: if path is not None:
return return
method_name = 'spec_for_{fullname}'.format(**locals()) method_name = f'spec_for_{fullname}'
method = getattr(self, method_name, lambda: None) method = getattr(self, method_name, lambda: None)
return method() return method()

View file

@ -1 +1,2 @@
from __future__ import annotations
__import__('_distutils_hack').do_override() __import__('_distutils_hack').do_override()

View file

@ -1,4 +1,5 @@
__all__ = ["__version__", "version_tuple"] from __future__ import annotations
__all__ = ['__version__', 'version_tuple']
try: try:
from ._version import version as __version__ from ._version import version as __version__
@ -6,5 +7,5 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
# broken installation, we don't even try # broken installation, we don't even try
# unknown only works because we do poor mans version compare # unknown only works because we do poor mans version compare
__version__ = "unknown" __version__ = 'unknown'
version_tuple = (0, 0, "unknown") # type:ignore[assignment] version_tuple = (0, 0, 'unknown') # type:ignore[assignment]

View file

@ -61,11 +61,12 @@ If things do not work right away:
which should throw a KeyError: 'COMPLINE' (which is properly set by the which should throw a KeyError: 'COMPLINE' (which is properly set by the
global argcomplete script). global argcomplete script).
""" """
from __future__ import annotations
import argparse import argparse
from glob import glob
import os import os
import sys import sys
from glob import glob
from typing import Any from typing import Any
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -77,7 +78,7 @@ class FastFilesCompleter:
def __init__(self, directories: bool = True) -> None: def __init__(self, directories: bool = True) -> None:
self.directories = directories self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> List[str]: def __call__(self, prefix: str, **kwargs: Any) -> list[str]:
# Only called on non option completions. # Only called on non option completions.
if os.sep in prefix[1:]: if os.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.sep) prefix_dir = len(os.path.dirname(prefix) + os.sep)
@ -85,26 +86,26 @@ class FastFilesCompleter:
prefix_dir = 0 prefix_dir = 0
completion = [] completion = []
globbed = [] globbed = []
if "*" not in prefix and "?" not in prefix: if '*' not in prefix and '?' not in prefix:
# We are on unix, otherwise no bash. # We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.sep: if not prefix or prefix[-1] == os.sep:
globbed.extend(glob(prefix + ".*")) globbed.extend(glob(prefix + '.*'))
prefix += "*" prefix += '*'
globbed.extend(glob(prefix)) globbed.extend(glob(prefix))
for x in sorted(globbed): for x in sorted(globbed):
if os.path.isdir(x): if os.path.isdir(x):
x += "/" x += '/'
# Append stripping the prefix (like bash, not like compgen). # Append stripping the prefix (like bash, not like compgen).
completion.append(x[prefix_dir:]) completion.append(x[prefix_dir:])
return completion return completion
if os.environ.get("_ARGCOMPLETE"): if os.environ.get('_ARGCOMPLETE'):
try: try:
import argcomplete.completers import argcomplete.completers
except ImportError: except ImportError:
sys.exit(-1) sys.exit(-1)
filescompleter: Optional[FastFilesCompleter] = FastFilesCompleter() filescompleter: FastFilesCompleter | None = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None: def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False) argcomplete.autocomplete(parser, always_complete_options=False)

View file

@ -1,4 +1,5 @@
"""Python inspection/code generation API.""" """Python inspection/code generation API."""
from __future__ import annotations
from .code import Code from .code import Code
from .code import ExceptionInfo from .code import ExceptionInfo
@ -12,13 +13,13 @@ from .source import Source
__all__ = [ __all__ = [
"Code", 'Code',
"ExceptionInfo", 'ExceptionInfo',
"filter_traceback", 'filter_traceback',
"Frame", 'Frame',
"getfslineno", 'getfslineno',
"getrawcode", 'getrawcode',
"Traceback", 'Traceback',
"TracebackEntry", 'TracebackEntry',
"Source", 'Source',
] ]

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,13 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import ast import ast
from bisect import bisect_right
import inspect import inspect
import textwrap import textwrap
import tokenize import tokenize
import types import types
import warnings
from bisect import bisect_right
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List from typing import List
@ -12,7 +15,6 @@ from typing import Optional
from typing import overload from typing import overload
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
import warnings
class Source: class Source:
@ -23,20 +25,20 @@ class Source:
def __init__(self, obj: object = None) -> None: def __init__(self, obj: object = None) -> None:
if not obj: if not obj:
self.lines: List[str] = [] self.lines: list[str] = []
elif isinstance(obj, Source): elif isinstance(obj, Source):
self.lines = obj.lines self.lines = obj.lines
elif isinstance(obj, (tuple, list)): elif isinstance(obj, (tuple, list)):
self.lines = deindent(x.rstrip("\n") for x in obj) self.lines = deindent(x.rstrip('\n') for x in obj)
elif isinstance(obj, str): elif isinstance(obj, str):
self.lines = deindent(obj.split("\n")) self.lines = deindent(obj.split('\n'))
else: else:
try: try:
rawcode = getrawcode(obj) rawcode = getrawcode(obj)
src = inspect.getsource(rawcode) src = inspect.getsource(rawcode)
except TypeError: except TypeError:
src = inspect.getsource(obj) # type: ignore[arg-type] src = inspect.getsource(obj) # type: ignore[arg-type]
self.lines = deindent(src.split("\n")) self.lines = deindent(src.split('\n'))
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, Source): if not isinstance(other, Source):
@ -51,17 +53,17 @@ class Source:
... ...
@overload @overload
def __getitem__(self, key: slice) -> "Source": def __getitem__(self, key: slice) -> Source:
... ...
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: def __getitem__(self, key: int | slice) -> str | Source:
if isinstance(key, int): if isinstance(key, int):
return self.lines[key] return self.lines[key]
else: else:
if key.step not in (None, 1): if key.step not in (None, 1):
raise IndexError("cannot slice a Source with a step") raise IndexError('cannot slice a Source with a step')
newsource = Source() newsource = Source()
newsource.lines = self.lines[key.start : key.stop] newsource.lines = self.lines[key.start: key.stop]
return newsource return newsource
def __iter__(self) -> Iterator[str]: def __iter__(self) -> Iterator[str]:
@ -70,7 +72,7 @@ class Source:
def __len__(self) -> int: def __len__(self) -> int:
return len(self.lines) return len(self.lines)
def strip(self) -> "Source": def strip(self) -> Source:
"""Return new Source object with trailing and leading blank lines removed.""" """Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self) start, end = 0, len(self)
while start < end and not self.lines[start].strip(): while start < end and not self.lines[start].strip():
@ -81,35 +83,35 @@ class Source:
source.lines[:] = self.lines[start:end] source.lines[:] = self.lines[start:end]
return source return source
def indent(self, indent: str = " " * 4) -> "Source": def indent(self, indent: str = ' ' * 4) -> Source:
"""Return a copy of the source object with all lines indented by the """Return a copy of the source object with all lines indented by the
given indent-string.""" given indent-string."""
newsource = Source() newsource = Source()
newsource.lines = [(indent + line) for line in self.lines] newsource.lines = [(indent + line) for line in self.lines]
return newsource return newsource
def getstatement(self, lineno: int) -> "Source": def getstatement(self, lineno: int) -> Source:
"""Return Source statement which contains the given linenumber """Return Source statement which contains the given linenumber
(counted from 0).""" (counted from 0)."""
start, end = self.getstatementrange(lineno) start, end = self.getstatementrange(lineno)
return self[start:end] return self[start:end]
def getstatementrange(self, lineno: int) -> Tuple[int, int]: def getstatementrange(self, lineno: int) -> tuple[int, int]:
"""Return (start, end) tuple which spans the minimal statement region """Return (start, end) tuple which spans the minimal statement region
which containing the given lineno.""" which containing the given lineno."""
if not (0 <= lineno < len(self)): if not (0 <= lineno < len(self)):
raise IndexError("lineno out of range") raise IndexError('lineno out of range')
ast, start, end = getstatementrange_ast(lineno, self) ast, start, end = getstatementrange_ast(lineno, self)
return start, end return start, end
def deindent(self) -> "Source": def deindent(self) -> Source:
"""Return a new Source object deindented.""" """Return a new Source object deindented."""
newsource = Source() newsource = Source()
newsource.lines[:] = deindent(self.lines) newsource.lines[:] = deindent(self.lines)
return newsource return newsource
def __str__(self) -> str: def __str__(self) -> str:
return "\n".join(self.lines) return '\n'.join(self.lines)
# #
@ -117,7 +119,7 @@ class Source:
# #
def findsource(obj) -> Tuple[Optional[Source], int]: def findsource(obj) -> tuple[Source | None, int]:
try: try:
sourcelines, lineno = inspect.findsource(obj) sourcelines, lineno = inspect.findsource(obj)
except Exception: except Exception:
@ -134,20 +136,20 @@ def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
except AttributeError: except AttributeError:
pass pass
if trycall: if trycall:
call = getattr(obj, "__call__", None) call = getattr(obj, '__call__', None)
if call and not isinstance(obj, type): if call and not isinstance(obj, type):
return getrawcode(call, trycall=False) return getrawcode(call, trycall=False)
raise TypeError(f"could not get code object for {obj!r}") raise TypeError(f'could not get code object for {obj!r}')
def deindent(lines: Iterable[str]) -> List[str]: def deindent(lines: Iterable[str]) -> list[str]:
return textwrap.dedent("\n".join(lines)).splitlines() return textwrap.dedent('\n'.join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]: def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
# Flatten all statements and except handlers into one lineno-list. # Flatten all statements and except handlers into one lineno-list.
# AST's line numbers start indexing at 1. # AST's line numbers start indexing at 1.
values: List[int] = [] values: list[int] = []
for x in ast.walk(node): for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)): if isinstance(x, (ast.stmt, ast.ExceptHandler)):
# The lineno points to the class/def, so need to include the decorators. # The lineno points to the class/def, so need to include the decorators.
@ -155,8 +157,8 @@ def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[i
for d in x.decorator_list: for d in x.decorator_list:
values.append(d.lineno - 1) values.append(d.lineno - 1)
values.append(x.lineno - 1) values.append(x.lineno - 1)
for name in ("finalbody", "orelse"): for name in ('finalbody', 'orelse'):
val: Optional[List[ast.stmt]] = getattr(x, name, None) val: list[ast.stmt] | None = getattr(x, name, None)
if val: if val:
# Treat the finally/orelse part as its own statement. # Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1) values.append(val[0].lineno - 1 - 1)
@ -174,15 +176,15 @@ def getstatementrange_ast(
lineno: int, lineno: int,
source: Source, source: Source,
assertion: bool = False, assertion: bool = False,
astnode: Optional[ast.AST] = None, astnode: ast.AST | None = None,
) -> Tuple[ast.AST, int, int]: ) -> tuple[ast.AST, int, int]:
if astnode is None: if astnode is None:
content = str(source) content = str(source)
# See #4260: # See #4260:
# Don't produce duplicate warnings when compiling source to find AST. # Don't produce duplicate warnings when compiling source to find AST.
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter('ignore')
astnode = ast.parse(content, "source", "exec") astnode = ast.parse(content, 'source', 'exec')
start, end = get_statement_startend2(lineno, astnode) start, end = get_statement_startend2(lineno, astnode)
# We need to correct the end: # We need to correct the end:
@ -200,7 +202,7 @@ def getstatementrange_ast(
block_finder.started = ( block_finder.started = (
bool(source.lines[start]) and source.lines[start][0].isspace() bool(source.lines[start]) and source.lines[start][0].isspace()
) )
it = ((x + "\n") for x in source.lines[start:end]) it = ((x + '\n') for x in source.lines[start:end])
try: try:
for tok in tokenize.generate_tokens(lambda: next(it)): for tok in tokenize.generate_tokens(lambda: next(it)):
block_finder.tokeneater(*tok) block_finder.tokeneater(*tok)
@ -212,7 +214,7 @@ def getstatementrange_ast(
# The end might still point to a comment or empty line, correct it. # The end might still point to a comment or empty line, correct it.
while end: while end:
line = source.lines[end - 1].lstrip() line = source.lines[end - 1].lstrip()
if line.startswith("#") or not line: if line.startswith('#') or not line:
end -= 1 end -= 1
else: else:
break break

View file

@ -1,8 +1,10 @@
from __future__ import annotations
from .terminalwriter import get_terminal_width from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter from .terminalwriter import TerminalWriter
__all__ = [ __all__ = [
"TerminalWriter", 'TerminalWriter',
"get_terminal_width", 'get_terminal_width',
] ]

View file

@ -13,11 +13,13 @@
# tuples with fairly non-descriptive content. This is modeled very much # tuples with fairly non-descriptive content. This is modeled very much
# after Lisp/Scheme - style pretty-printing of lists. If you find it # after Lisp/Scheme - style pretty-printing of lists. If you find it
# useful, thank small children who sleep at night. # useful, thank small children who sleep at night.
from __future__ import annotations
import collections as _collections import collections as _collections
import dataclasses as _dataclasses import dataclasses as _dataclasses
from io import StringIO as _StringIO
import re import re
import types as _types import types as _types
from io import StringIO as _StringIO
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
@ -39,7 +41,7 @@ class _safe_key:
""" """
__slots__ = ["obj"] __slots__ = ['obj']
def __init__(self, obj): def __init__(self, obj):
self.obj = obj self.obj = obj
@ -64,7 +66,7 @@ class PrettyPrinter:
self, self,
indent: int = 4, indent: int = 4,
width: int = 80, width: int = 80,
depth: Optional[int] = None, depth: int | None = None,
) -> None: ) -> None:
"""Handle pretty printing operations onto a stream using a set of """Handle pretty printing operations onto a stream using a set of
configured parameters. configured parameters.
@ -80,11 +82,11 @@ class PrettyPrinter:
""" """
if indent < 0: if indent < 0:
raise ValueError("indent must be >= 0") raise ValueError('indent must be >= 0')
if depth is not None and depth <= 0: if depth is not None and depth <= 0:
raise ValueError("depth must be > 0") raise ValueError('depth must be > 0')
if not width: if not width:
raise ValueError("width must be != 0") raise ValueError('width must be != 0')
self._depth = depth self._depth = depth
self._indent_per_level = indent self._indent_per_level = indent
self._width = width self._width = width
@ -100,7 +102,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
objid = id(object) objid = id(object)
@ -114,17 +116,16 @@ class PrettyPrinter:
p(self, object, stream, indent, allowance, context, level + 1) p(self, object, stream, indent, allowance, context, level + 1)
context.remove(objid) context.remove(objid)
elif ( elif (
_dataclasses.is_dataclass(object) _dataclasses.is_dataclass(object) and
and not isinstance(object, type) not isinstance(object, type) and
and object.__dataclass_params__.repr object.__dataclass_params__.repr and
and
# Check dataclass has generated repr method. # Check dataclass has generated repr method.
hasattr(object.__repr__, "__wrapped__") hasattr(object.__repr__, '__wrapped__') and
and "__create_fn__" in object.__repr__.__wrapped__.__qualname__ '__create_fn__' in object.__repr__.__wrapped__.__qualname__
): ):
context.add(objid) context.add(objid)
self._pprint_dataclass( self._pprint_dataclass(
object, stream, indent, allowance, context, level + 1 object, stream, indent, allowance, context, level + 1,
) )
context.remove(objid) context.remove(objid)
else: else:
@ -136,7 +137,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
cls_name = object.__class__.__name__ cls_name = object.__class__.__name__
@ -145,13 +146,13 @@ class PrettyPrinter:
for f in _dataclasses.fields(object) for f in _dataclasses.fields(object)
if f.repr if f.repr
] ]
stream.write(cls_name + "(") stream.write(cls_name + '(')
self._format_namespace_items(items, stream, indent, allowance, context, level) self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")") stream.write(')')
_dispatch: Dict[ _dispatch: dict[
Callable[..., str], Callable[..., str],
Callable[["PrettyPrinter", Any, IO[str], int, int, Set[int], int], None], Callable[[PrettyPrinter, Any, IO[str], int, int, set[int], int], None],
] = {} ] = {}
def _pprint_dict( def _pprint_dict(
@ -160,14 +161,14 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
write("{") write('{')
items = sorted(object.items(), key=_safe_tuple) items = sorted(object.items(), key=_safe_tuple)
self._format_dict_items(items, stream, indent, allowance, context, level) self._format_dict_items(items, stream, indent, allowance, context, level)
write("}") write('}')
_dispatch[dict.__repr__] = _pprint_dict _dispatch[dict.__repr__] = _pprint_dict
@ -177,16 +178,16 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not len(object): if not len(object):
stream.write(repr(object)) stream.write(repr(object))
return return
cls = object.__class__ cls = object.__class__
stream.write(cls.__name__ + "(") stream.write(cls.__name__ + '(')
self._pprint_dict(object, stream, indent, allowance, context, level) self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")") stream.write(')')
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict _dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
@ -196,12 +197,12 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write("[") stream.write('[')
self._format_items(object, stream, indent, allowance, context, level) self._format_items(object, stream, indent, allowance, context, level)
stream.write("]") stream.write(']')
_dispatch[list.__repr__] = _pprint_list _dispatch[list.__repr__] = _pprint_list
@ -211,12 +212,12 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write("(") stream.write('(')
self._format_items(object, stream, indent, allowance, context, level) self._format_items(object, stream, indent, allowance, context, level)
stream.write(")") stream.write(')')
_dispatch[tuple.__repr__] = _pprint_tuple _dispatch[tuple.__repr__] = _pprint_tuple
@ -226,7 +227,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not len(object): if not len(object):
@ -234,11 +235,11 @@ class PrettyPrinter:
return return
typ = object.__class__ typ = object.__class__
if typ is set: if typ is set:
stream.write("{") stream.write('{')
endchar = "}" endchar = '}'
else: else:
stream.write(typ.__name__ + "({") stream.write(typ.__name__ + '({')
endchar = "})" endchar = '})'
object = sorted(object, key=_safe_key) object = sorted(object, key=_safe_key)
self._format_items(object, stream, indent, allowance, context, level) self._format_items(object, stream, indent, allowance, context, level)
stream.write(endchar) stream.write(endchar)
@ -252,7 +253,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
@ -273,12 +274,12 @@ class PrettyPrinter:
chunks.append(rep) chunks.append(rep)
else: else:
# A list of alternating (non-space, space) strings # A list of alternating (non-space, space) strings
parts = re.findall(r"\S*\s*", line) parts = re.findall(r'\S*\s*', line)
assert parts assert parts
assert not parts[-1] assert not parts[-1]
parts.pop() # drop empty last part parts.pop() # drop empty last part
max_width2 = max_width max_width2 = max_width
current = "" current = ''
for j, part in enumerate(parts): for j, part in enumerate(parts):
candidate = current + part candidate = current + part
if j == len(parts) - 1 and i == len(lines) - 1: if j == len(parts) - 1 and i == len(lines) - 1:
@ -295,13 +296,13 @@ class PrettyPrinter:
write(rep) write(rep)
return return
if level == 1: if level == 1:
write("(") write('(')
for i, rep in enumerate(chunks): for i, rep in enumerate(chunks):
if i > 0: if i > 0:
write("\n" + " " * indent) write('\n' + ' ' * indent)
write(rep) write(rep)
if level == 1: if level == 1:
write(")") write(')')
_dispatch[str.__repr__] = _pprint_str _dispatch[str.__repr__] = _pprint_str
@ -311,7 +312,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
@ -322,15 +323,15 @@ class PrettyPrinter:
if parens: if parens:
indent += 1 indent += 1
allowance += 1 allowance += 1
write("(") write('(')
delim = "" delim = ''
for rep in _wrap_bytes_repr(object, self._width - indent, allowance): for rep in _wrap_bytes_repr(object, self._width - indent, allowance):
write(delim) write(delim)
write(rep) write(rep)
if not delim: if not delim:
delim = "\n" + " " * indent delim = '\n' + ' ' * indent
if parens: if parens:
write(")") write(')')
_dispatch[bytes.__repr__] = _pprint_bytes _dispatch[bytes.__repr__] = _pprint_bytes
@ -340,15 +341,15 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
write("bytearray(") write('bytearray(')
self._pprint_bytes( self._pprint_bytes(
bytes(object), stream, indent + 10, allowance + 1, context, level + 1 bytes(object), stream, indent + 10, allowance + 1, context, level + 1,
) )
write(")") write(')')
_dispatch[bytearray.__repr__] = _pprint_bytearray _dispatch[bytearray.__repr__] = _pprint_bytearray
@ -358,12 +359,12 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write("mappingproxy(") stream.write('mappingproxy(')
self._format(object.copy(), stream, indent, allowance, context, level) self._format(object.copy(), stream, indent, allowance, context, level)
stream.write(")") stream.write(')')
_dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy _dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy
@ -373,29 +374,29 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if type(object) is _types.SimpleNamespace: if type(object) is _types.SimpleNamespace:
# The SimpleNamespace repr is "namespace" instead of the class # The SimpleNamespace repr is "namespace" instead of the class
# name, so we do the same here. For subclasses; use the class name. # name, so we do the same here. For subclasses; use the class name.
cls_name = "namespace" cls_name = 'namespace'
else: else:
cls_name = object.__class__.__name__ cls_name = object.__class__.__name__
items = object.__dict__.items() items = object.__dict__.items()
stream.write(cls_name + "(") stream.write(cls_name + '(')
self._format_namespace_items(items, stream, indent, allowance, context, level) self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")") stream.write(')')
_dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace _dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace
def _format_dict_items( def _format_dict_items(
self, self,
items: List[Tuple[Any, Any]], items: list[tuple[Any, Any]],
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not items: if not items:
@ -403,23 +404,23 @@ class PrettyPrinter:
write = stream.write write = stream.write
item_indent = indent + self._indent_per_level item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent delimnl = '\n' + ' ' * item_indent
for key, ent in items: for key, ent in items:
write(delimnl) write(delimnl)
write(self._repr(key, context, level)) write(self._repr(key, context, level))
write(": ") write(': ')
self._format(ent, stream, item_indent, 1, context, level) self._format(ent, stream, item_indent, 1, context, level)
write(",") write(',')
write("\n" + " " * indent) write('\n' + ' ' * indent)
def _format_namespace_items( def _format_namespace_items(
self, self,
items: List[Tuple[Any, Any]], items: list[tuple[Any, Any]],
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not items: if not items:
@ -427,15 +428,15 @@ class PrettyPrinter:
write = stream.write write = stream.write
item_indent = indent + self._indent_per_level item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent delimnl = '\n' + ' ' * item_indent
for key, ent in items: for key, ent in items:
write(delimnl) write(delimnl)
write(key) write(key)
write("=") write('=')
if id(ent) in context: if id(ent) in context:
# Special-case representation of recursion to match standard # Special-case representation of recursion to match standard
# recursive dataclass repr. # recursive dataclass repr.
write("...") write('...')
else: else:
self._format( self._format(
ent, ent,
@ -446,17 +447,17 @@ class PrettyPrinter:
level, level,
) )
write(",") write(',')
write("\n" + " " * indent) write('\n' + ' ' * indent)
def _format_items( def _format_items(
self, self,
items: List[Any], items: list[Any],
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not items: if not items:
@ -464,16 +465,16 @@ class PrettyPrinter:
write = stream.write write = stream.write
item_indent = indent + self._indent_per_level item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent delimnl = '\n' + ' ' * item_indent
for item in items: for item in items:
write(delimnl) write(delimnl)
self._format(item, stream, item_indent, 1, context, level) self._format(item, stream, item_indent, 1, context, level)
write(",") write(',')
write("\n" + " " * indent) write('\n' + ' ' * indent)
def _repr(self, object: Any, context: Set[int], level: int) -> str: def _repr(self, object: Any, context: set[int], level: int) -> str:
return self._safe_repr(object, context.copy(), self._depth, level) return self._safe_repr(object, context.copy(), self._depth, level)
def _pprint_default_dict( def _pprint_default_dict(
@ -482,13 +483,13 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
rdf = self._repr(object.default_factory, context, level) rdf = self._repr(object.default_factory, context, level)
stream.write(f"{object.__class__.__name__}({rdf}, ") stream.write(f'{object.__class__.__name__}({rdf}, ')
self._pprint_dict(object, stream, indent, allowance, context, level) self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")") stream.write(')')
_dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict _dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict
@ -498,18 +499,18 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write(object.__class__.__name__ + "(") stream.write(object.__class__.__name__ + '(')
if object: if object:
stream.write("{") stream.write('{')
items = object.most_common() items = object.most_common()
self._format_dict_items(items, stream, indent, allowance, context, level) self._format_dict_items(items, stream, indent, allowance, context, level)
stream.write("}") stream.write('}')
stream.write(")") stream.write(')')
_dispatch[_collections.Counter.__repr__] = _pprint_counter _dispatch[_collections.Counter.__repr__] = _pprint_counter
@ -519,16 +520,16 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])): if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
stream.write(repr(object)) stream.write(repr(object))
return return
stream.write(object.__class__.__name__ + "(") stream.write(object.__class__.__name__ + '(')
self._format_items(object.maps, stream, indent, allowance, context, level) self._format_items(object.maps, stream, indent, allowance, context, level)
stream.write(")") stream.write(')')
_dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map _dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map
@ -538,16 +539,16 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write(object.__class__.__name__ + "(") stream.write(object.__class__.__name__ + '(')
if object.maxlen is not None: if object.maxlen is not None:
stream.write("maxlen=%d, " % object.maxlen) stream.write('maxlen=%d, ' % object.maxlen)
stream.write("[") stream.write('[')
self._format_items(object, stream, indent, allowance + 1, context, level) self._format_items(object, stream, indent, allowance + 1, context, level)
stream.write("])") stream.write('])')
_dispatch[_collections.deque.__repr__] = _pprint_deque _dispatch[_collections.deque.__repr__] = _pprint_deque
@ -557,7 +558,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
@ -570,7 +571,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
@ -583,7 +584,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
@ -591,49 +592,49 @@ class PrettyPrinter:
_dispatch[_collections.UserString.__repr__] = _pprint_user_string _dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr( def _safe_repr(
self, object: Any, context: Set[int], maxlevels: Optional[int], level: int self, object: Any, context: set[int], maxlevels: int | None, level: int,
) -> str: ) -> str:
typ = type(object) typ = type(object)
if typ in _builtin_scalars: if typ in _builtin_scalars:
return repr(object) return repr(object)
r = getattr(typ, "__repr__", None) r = getattr(typ, '__repr__', None)
if issubclass(typ, dict) and r is dict.__repr__: if issubclass(typ, dict) and r is dict.__repr__:
if not object: if not object:
return "{}" return '{}'
objid = id(object) objid = id(object)
if maxlevels and level >= maxlevels: if maxlevels and level >= maxlevels:
return "{...}" return '{...}'
if objid in context: if objid in context:
return _recursion(object) return _recursion(object)
context.add(objid) context.add(objid)
components: List[str] = [] components: list[str] = []
append = components.append append = components.append
level += 1 level += 1
for k, v in sorted(object.items(), key=_safe_tuple): for k, v in sorted(object.items(), key=_safe_tuple):
krepr = self._safe_repr(k, context, maxlevels, level) krepr = self._safe_repr(k, context, maxlevels, level)
vrepr = self._safe_repr(v, context, maxlevels, level) vrepr = self._safe_repr(v, context, maxlevels, level)
append(f"{krepr}: {vrepr}") append(f'{krepr}: {vrepr}')
context.remove(objid) context.remove(objid)
return "{%s}" % ", ".join(components) return '{%s}' % ', '.join(components)
if (issubclass(typ, list) and r is list.__repr__) or ( if (issubclass(typ, list) and r is list.__repr__) or (
issubclass(typ, tuple) and r is tuple.__repr__ issubclass(typ, tuple) and r is tuple.__repr__
): ):
if issubclass(typ, list): if issubclass(typ, list):
if not object: if not object:
return "[]" return '[]'
format = "[%s]" format = '[%s]'
elif len(object) == 1: elif len(object) == 1:
format = "(%s,)" format = '(%s,)'
else: else:
if not object: if not object:
return "()" return '()'
format = "(%s)" format = '(%s)'
objid = id(object) objid = id(object)
if maxlevels and level >= maxlevels: if maxlevels and level >= maxlevels:
return format % "..." return format % '...'
if objid in context: if objid in context:
return _recursion(object) return _recursion(object)
context.add(objid) context.add(objid)
@ -644,25 +645,25 @@ class PrettyPrinter:
orepr = self._safe_repr(o, context, maxlevels, level) orepr = self._safe_repr(o, context, maxlevels, level)
append(orepr) append(orepr)
context.remove(objid) context.remove(objid)
return format % ", ".join(components) return format % ', '.join(components)
return repr(object) return repr(object)
_builtin_scalars = frozenset( _builtin_scalars = frozenset(
{str, bytes, bytearray, float, complex, bool, type(None), int} {str, bytes, bytearray, float, complex, bool, type(None), int},
) )
def _recursion(object: Any) -> str: def _recursion(object: Any) -> str:
return f"<Recursion on {type(object).__name__} with id={id(object)}>" return f'<Recursion on {type(object).__name__} with id={id(object)}>'
def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]: def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
current = b"" current = b''
last = len(object) // 4 * 4 last = len(object) // 4 * 4
for i in range(0, len(object), 4): for i in range(0, len(object), 4):
part = object[i : i + 4] part = object[i: i + 4]
candidate = current + part candidate = current + part
if i == last: if i == last:
width -= allowance width -= allowance

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import pprint import pprint
import reprlib import reprlib
from typing import Optional from typing import Optional
@ -18,9 +20,9 @@ def _format_repr_exception(exc: BaseException, obj: object) -> str:
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
raise raise
except BaseException as exc: except BaseException as exc:
exc_info = f"unpresentable exception ({_try_repr_or_str(exc)})" exc_info = f'unpresentable exception ({_try_repr_or_str(exc)})'
return ( return (
f"<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>" f'<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>'
) )
@ -28,7 +30,7 @@ def _ellipsize(s: str, maxsize: int) -> str:
if len(s) > maxsize: if len(s) > maxsize:
i = max(0, (maxsize - 3) // 2) i = max(0, (maxsize - 3) // 2)
j = max(0, maxsize - 3 - i) j = max(0, maxsize - 3 - i)
return s[:i] + "..." + s[len(s) - j :] return s[:i] + '...' + s[len(s) - j:]
return s return s
@ -38,7 +40,7 @@ class SafeRepr(reprlib.Repr):
information on exceptions raised during the call. information on exceptions raised during the call.
""" """
def __init__(self, maxsize: Optional[int], use_ascii: bool = False) -> None: def __init__(self, maxsize: int | None, use_ascii: bool = False) -> None:
""" """
:param maxsize: :param maxsize:
If not None, will truncate the resulting repr to that specific size, using ellipsis If not None, will truncate the resulting repr to that specific size, using ellipsis
@ -97,7 +99,7 @@ DEFAULT_REPR_MAX_SIZE = 240
def saferepr( def saferepr(
obj: object, maxsize: Optional[int] = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False obj: object, maxsize: int | None = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False,
) -> str: ) -> str:
"""Return a size-limited safe repr-string for the given object. """Return a size-limited safe repr-string for the given object.

View file

@ -1,4 +1,5 @@
"""Helper functions for writing to terminals and files.""" """Helper functions for writing to terminals and files."""
from __future__ import annotations
import os import os
import shutil import shutil
@ -26,16 +27,16 @@ def get_terminal_width() -> int:
def should_do_markup(file: TextIO) -> bool: def should_do_markup(file: TextIO) -> bool:
if os.environ.get("PY_COLORS") == "1": if os.environ.get('PY_COLORS') == '1':
return True return True
if os.environ.get("PY_COLORS") == "0": if os.environ.get('PY_COLORS') == '0':
return False return False
if os.environ.get("NO_COLOR"): if os.environ.get('NO_COLOR'):
return False return False
if os.environ.get("FORCE_COLOR"): if os.environ.get('FORCE_COLOR'):
return True return True
return ( return (
hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb" hasattr(file, 'isatty') and file.isatty() and os.environ.get('TERM') != 'dumb'
) )
@ -64,10 +65,10 @@ class TerminalWriter:
invert=7, invert=7,
) )
def __init__(self, file: Optional[TextIO] = None) -> None: def __init__(self, file: TextIO | None = None) -> None:
if file is None: if file is None:
file = sys.stdout file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32": if hasattr(file, 'isatty') and file.isatty() and sys.platform == 'win32':
try: try:
import colorama import colorama
except ImportError: except ImportError:
@ -77,8 +78,8 @@ class TerminalWriter:
assert file is not None assert file is not None
self._file = file self._file = file
self.hasmarkup = should_do_markup(file) self.hasmarkup = should_do_markup(file)
self._current_line = "" self._current_line = ''
self._terminal_width: Optional[int] = None self._terminal_width: int | None = None
self.code_highlight = True self.code_highlight = True
@property @property
@ -99,25 +100,25 @@ class TerminalWriter:
def markup(self, text: str, **markup: bool) -> str: def markup(self, text: str, **markup: bool) -> str:
for name in markup: for name in markup:
if name not in self._esctable: if name not in self._esctable:
raise ValueError(f"unknown markup: {name!r}") raise ValueError(f'unknown markup: {name!r}')
if self.hasmarkup: if self.hasmarkup:
esc = [self._esctable[name] for name, on in markup.items() if on] esc = [self._esctable[name] for name, on in markup.items() if on]
if esc: if esc:
text = "".join("\x1b[%sm" % cod for cod in esc) + text + "\x1b[0m" text = ''.join('\x1b[%sm' % cod for cod in esc) + text + '\x1b[0m'
return text return text
def sep( def sep(
self, self,
sepchar: str, sepchar: str,
title: Optional[str] = None, title: str | None = None,
fullwidth: Optional[int] = None, fullwidth: int | None = None,
**markup: bool, **markup: bool,
) -> None: ) -> None:
if fullwidth is None: if fullwidth is None:
fullwidth = self.fullwidth fullwidth = self.fullwidth
# The goal is to have the line be as long as possible # The goal is to have the line be as long as possible
# under the condition that len(line) <= fullwidth. # under the condition that len(line) <= fullwidth.
if sys.platform == "win32": if sys.platform == 'win32':
# If we print in the last column on windows we are on a # If we print in the last column on windows we are on a
# new line but there is no way to verify/neutralize this # new line but there is no way to verify/neutralize this
# (we may not know the exact line width). # (we may not know the exact line width).
@ -130,7 +131,7 @@ class TerminalWriter:
# N <= (fullwidth - len(title) - 2) // (2*len(sepchar)) # N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1) N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
fill = sepchar * N fill = sepchar * N
line = f"{fill} {title} {fill}" line = f'{fill} {title} {fill}'
else: else:
# we want len(sepchar)*N <= fullwidth # we want len(sepchar)*N <= fullwidth
# i.e. N <= fullwidth // len(sepchar) # i.e. N <= fullwidth // len(sepchar)
@ -145,8 +146,8 @@ class TerminalWriter:
def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None: def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
if msg: if msg:
current_line = msg.rsplit("\n", 1)[-1] current_line = msg.rsplit('\n', 1)[-1]
if "\n" in msg: if '\n' in msg:
self._current_line = current_line self._current_line = current_line
else: else:
self._current_line += current_line self._current_line += current_line
@ -162,15 +163,15 @@ class TerminalWriter:
# When the Unicode situation improves we should consider # When the Unicode situation improves we should consider
# letting the error propagate instead of masking it (see #7475 # letting the error propagate instead of masking it (see #7475
# for one brief attempt). # for one brief attempt).
msg = msg.encode("unicode-escape").decode("ascii") msg = msg.encode('unicode-escape').decode('ascii')
self._file.write(msg) self._file.write(msg)
if flush: if flush:
self.flush() self.flush()
def line(self, s: str = "", **markup: bool) -> None: def line(self, s: str = '', **markup: bool) -> None:
self.write(s, **markup) self.write(s, **markup)
self.write("\n") self.write('\n')
def flush(self) -> None: def flush(self) -> None:
self._file.flush() self._file.flush()
@ -184,17 +185,17 @@ class TerminalWriter:
""" """
if indents and len(indents) != len(lines): if indents and len(indents) != len(lines):
raise ValueError( raise ValueError(
f"indents size ({len(indents)}) should have same size as lines ({len(lines)})" f'indents size ({len(indents)}) should have same size as lines ({len(lines)})',
) )
if not indents: if not indents:
indents = [""] * len(lines) indents = [''] * len(lines)
source = "\n".join(lines) source = '\n'.join(lines)
new_lines = self._highlight(source).splitlines() new_lines = self._highlight(source).splitlines()
for indent, new_line in zip(indents, new_lines): for indent, new_line in zip(indents, new_lines):
self.line(indent + new_line) self.line(indent + new_line)
def _highlight( def _highlight(
self, source: str, lexer: Literal["diff", "python"] = "python" self, source: str, lexer: Literal['diff', 'python'] = 'python',
) -> str: ) -> str:
"""Highlight the given source if we have markup support.""" """Highlight the given source if we have markup support."""
from _pytest.config.exceptions import UsageError from _pytest.config.exceptions import UsageError
@ -205,9 +206,9 @@ class TerminalWriter:
try: try:
from pygments.formatters.terminal import TerminalFormatter from pygments.formatters.terminal import TerminalFormatter
if lexer == "python": if lexer == 'python':
from pygments.lexers.python import PythonLexer as Lexer from pygments.lexers.python import PythonLexer as Lexer
elif lexer == "diff": elif lexer == 'diff':
from pygments.lexers.diff import DiffLexer as Lexer from pygments.lexers.diff import DiffLexer as Lexer
from pygments import highlight from pygments import highlight
import pygments.util import pygments.util
@ -219,30 +220,30 @@ class TerminalWriter:
source, source,
Lexer(), Lexer(),
TerminalFormatter( TerminalFormatter(
bg=os.getenv("PYTEST_THEME_MODE", "dark"), bg=os.getenv('PYTEST_THEME_MODE', 'dark'),
style=os.getenv("PYTEST_THEME"), style=os.getenv('PYTEST_THEME'),
), ),
) )
# pygments terminal formatter may add a newline when there wasn't one. # pygments terminal formatter may add a newline when there wasn't one.
# We don't want this, remove. # We don't want this, remove.
if highlighted[-1] == "\n" and source[-1] != "\n": if highlighted[-1] == '\n' and source[-1] != '\n':
highlighted = highlighted[:-1] highlighted = highlighted[:-1]
# Some lexers will not set the initial color explicitly # Some lexers will not set the initial color explicitly
# which may lead to the previous color being propagated to the # which may lead to the previous color being propagated to the
# start of the expression, so reset first. # start of the expression, so reset first.
return "\x1b[0m" + highlighted return '\x1b[0m' + highlighted
except pygments.util.ClassNotFound as e: except pygments.util.ClassNotFound as e:
raise UsageError( raise UsageError(
"PYTEST_THEME environment variable had an invalid value: '{}'. " "PYTEST_THEME environment variable had an invalid value: '{}'. "
"Only valid pygment styles are allowed.".format( 'Only valid pygment styles are allowed.'.format(
os.getenv("PYTEST_THEME") os.getenv('PYTEST_THEME'),
) ),
) from e ) from e
except pygments.util.OptionError as e: except pygments.util.OptionError as e:
raise UsageError( raise UsageError(
"PYTEST_THEME_MODE environment variable had an invalid value: '{}'. " "PYTEST_THEME_MODE environment variable had an invalid value: '{}'. "
"The only allowed values are 'dark' and 'light'.".format( "The only allowed values are 'dark' and 'light'.".format(
os.getenv("PYTEST_THEME_MODE") os.getenv('PYTEST_THEME_MODE'),
) ),
) from e ) from e

View file

@ -1,5 +1,7 @@
from functools import lru_cache from __future__ import annotations
import unicodedata import unicodedata
from functools import lru_cache
@lru_cache(100) @lru_cache(100)
@ -17,25 +19,25 @@ def wcwidth(c: str) -> int:
# Some Cf/Zp/Zl characters which should be zero-width. # Some Cf/Zp/Zl characters which should be zero-width.
if ( if (
o == 0x0000 o == 0x0000 or
or 0x200B <= o <= 0x200F 0x200B <= o <= 0x200F or
or 0x2028 <= o <= 0x202E 0x2028 <= o <= 0x202E or
or 0x2060 <= o <= 0x2063 0x2060 <= o <= 0x2063
): ):
return 0 return 0
category = unicodedata.category(c) category = unicodedata.category(c)
# Control characters. # Control characters.
if category == "Cc": if category == 'Cc':
return -1 return -1
# Combining characters with zero width. # Combining characters with zero width.
if category in ("Me", "Mn"): if category in ('Me', 'Mn'):
return 0 return 0
# Full/Wide east asian characters. # Full/Wide east asian characters.
if unicodedata.east_asian_width(c) in ("F", "W"): if unicodedata.east_asian_width(c) in ('F', 'W'):
return 2 return 2
return 1 return 1
@ -47,7 +49,7 @@ def wcswidth(s: str) -> int:
Returns -1 if the string contains non-printable characters. Returns -1 if the string contains non-printable characters.
""" """
width = 0 width = 0
for c in unicodedata.normalize("NFC", s): for c in unicodedata.normalize('NFC', s):
wc = wcwidth(c) wc = wcwidth(c)
if wc < 0: if wc < 0:
return -1 return -1

View file

@ -1,5 +1,4 @@
"""create errno-specific classes for IO or os calls.""" """create errno-specific classes for IO or os calls."""
from __future__ import annotations from __future__ import annotations
import errno import errno
@ -13,25 +12,25 @@ from typing import TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
P = ParamSpec("P") P = ParamSpec('P')
R = TypeVar("R") R = TypeVar('R')
class Error(EnvironmentError): class Error(EnvironmentError):
def __repr__(self) -> str: def __repr__(self) -> str:
return "{}.{} {!r}: {} ".format( return '{}.{} {!r}: {} '.format(
self.__class__.__module__, self.__class__.__module__,
self.__class__.__name__, self.__class__.__name__,
self.__class__.__doc__, self.__class__.__doc__,
" ".join(map(str, self.args)), ' '.join(map(str, self.args)),
# repr(self.args) # repr(self.args)
) )
def __str__(self) -> str: def __str__(self) -> str:
s = "[{}]: {}".format( s = '[{}]: {}'.format(
self.__class__.__doc__, self.__class__.__doc__,
" ".join(map(str, self.args)), ' '.join(map(str, self.args)),
) )
return s return s
@ -58,7 +57,7 @@ class ErrorMaker:
_errno2class: dict[int, type[Error]] = {} _errno2class: dict[int, type[Error]] = {}
def __getattr__(self, name: str) -> type[Error]: def __getattr__(self, name: str) -> type[Error]:
if name[0] == "_": if name[0] == '_':
raise AttributeError(name) raise AttributeError(name)
eno = getattr(errno, name) eno = getattr(errno, name)
cls = self._geterrnoclass(eno) cls = self._geterrnoclass(eno)
@ -69,17 +68,17 @@ class ErrorMaker:
try: try:
return self._errno2class[eno] return self._errno2class[eno]
except KeyError: except KeyError:
clsname = errno.errorcode.get(eno, "UnknownErrno%d" % (eno,)) clsname = errno.errorcode.get(eno, 'UnknownErrno%d' % (eno,))
errorcls = type( errorcls = type(
clsname, clsname,
(Error,), (Error,),
{"__module__": "py.error", "__doc__": os.strerror(eno)}, {'__module__': 'py.error', '__doc__': os.strerror(eno)},
) )
self._errno2class[eno] = errorcls self._errno2class[eno] = errorcls
return errorcls return errorcls
def checked_call( def checked_call(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs,
) -> R: ) -> R:
"""Call a function and raise an errno-exception if applicable.""" """Call a function and raise an errno-exception if applicable."""
__tracebackhide__ = True __tracebackhide__ = True
@ -88,10 +87,10 @@ class ErrorMaker:
except Error: except Error:
raise raise
except OSError as value: except OSError as value:
if not hasattr(value, "errno"): if not hasattr(value, 'errno'):
raise raise
errno = value.errno errno = value.errno
if sys.platform == "win32": if sys.platform == 'win32':
try: try:
cls = self._geterrnoclass(_winerrnomap[errno]) cls = self._geterrnoclass(_winerrnomap[errno])
except KeyError: except KeyError:
@ -100,7 +99,7 @@ class ErrorMaker:
# we are not on Windows, or we got a proper OSError # we are not on Windows, or we got a proper OSError
cls = self._geterrnoclass(errno) cls = self._geterrnoclass(errno)
raise cls(f"{func.__name__}{args!r}") raise cls(f'{func.__name__}{args!r}')
_error_maker = ErrorMaker() _error_maker = ErrorMaker()

View file

@ -3,11 +3,15 @@
from __future__ import annotations from __future__ import annotations
import atexit import atexit
from contextlib import contextmanager
import fnmatch import fnmatch
import importlib.util import importlib.util
import io import io
import os import os
import posixpath
import sys
import uuid
import warnings
from contextlib import contextmanager
from os.path import abspath from os.path import abspath
from os.path import dirname from os.path import dirname
from os.path import exists from os.path import exists
@ -16,39 +20,35 @@ from os.path import isdir
from os.path import isfile from os.path import isfile
from os.path import islink from os.path import islink
from os.path import normpath from os.path import normpath
import posixpath
from stat import S_ISDIR from stat import S_ISDIR
from stat import S_ISLNK from stat import S_ISLNK
from stat import S_ISREG from stat import S_ISREG
import sys
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast from typing import cast
from typing import Literal from typing import Literal
from typing import overload from typing import overload
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import uuid
import warnings
from . import error from . import error
# Moved from local.py. # Moved from local.py.
iswin32 = sys.platform == "win32" or (getattr(os, "_name", False) == "nt") iswin32 = sys.platform == 'win32' or (getattr(os, '_name', False) == 'nt')
class Checkers: class Checkers:
_depend_on_existence = "exists", "link", "dir", "file" _depend_on_existence = 'exists', 'link', 'dir', 'file'
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
def dotfile(self): def dotfile(self):
return self.path.basename.startswith(".") return self.path.basename.startswith('.')
def ext(self, arg): def ext(self, arg):
if not arg.startswith("."): if not arg.startswith('.'):
arg = "." + arg arg = '.' + arg
return self.path.ext == arg return self.path.ext == arg
def basename(self, arg): def basename(self, arg):
@ -75,14 +75,14 @@ class Checkers:
try: try:
meth = getattr(self, name) meth = getattr(self, name)
except AttributeError: except AttributeError:
if name[:3] == "not": if name[:3] == 'not':
invert = True invert = True
try: try:
meth = getattr(self, name[3:]) meth = getattr(self, name[3:])
except AttributeError: except AttributeError:
pass pass
if meth is None: if meth is None:
raise TypeError(f"no {name!r} checker available for {self.path!r}") raise TypeError(f'no {name!r} checker available for {self.path!r}')
try: try:
if getrawcode(meth).co_argcount > 1: if getrawcode(meth).co_argcount > 1:
if (not meth(value)) ^ invert: if (not meth(value)) ^ invert:
@ -98,7 +98,7 @@ class Checkers:
if name in kw: if name in kw:
if kw.get(name): if kw.get(name):
return False return False
name = "not" + name name = 'not' + name
if name in kw: if name in kw:
if not kw.get(name): if not kw.get(name):
return False return False
@ -140,7 +140,7 @@ class Visitor:
fil = FNMatcher(fil) fil = FNMatcher(fil)
if isinstance(rec, str): if isinstance(rec, str):
self.rec: Callable[[LocalPath], bool] = FNMatcher(rec) self.rec: Callable[[LocalPath], bool] = FNMatcher(rec)
elif not hasattr(rec, "__call__") and rec: elif not hasattr(rec, '__call__') and rec:
self.rec = lambda path: True self.rec = lambda path: True
else: else:
self.rec = rec self.rec = rec
@ -156,7 +156,7 @@ class Visitor:
return return
rec = self.rec rec = self.rec
dirs = self.optsort( dirs = self.optsort(
[p for p in entries if p.check(dir=1) and (rec is None or rec(p))] [p for p in entries if p.check(dir=1) and (rec is None or rec(p))],
) )
if not self.breadthfirst: if not self.breadthfirst:
for subdir in dirs: for subdir in dirs:
@ -179,9 +179,9 @@ class FNMatcher:
pattern = self.pattern pattern = self.pattern
if ( if (
pattern.find(path.sep) == -1 pattern.find(path.sep) == -1 and
and iswin32 iswin32 and
and pattern.find(posixpath.sep) != -1 pattern.find(posixpath.sep) != -1
): ):
# Running on Windows, the pattern has no Windows path separators, # Running on Windows, the pattern has no Windows path separators,
# and the pattern has one or more Posix path separators. Replace # and the pattern has one or more Posix path separators. Replace
@ -193,7 +193,7 @@ class FNMatcher:
else: else:
name = str(path) # path.strpath # XXX svn? name = str(path) # path.strpath # XXX svn?
if not os.path.isabs(pattern): if not os.path.isabs(pattern):
pattern = "*" + path.sep + pattern pattern = '*' + path.sep + pattern
return fnmatch.fnmatch(name, pattern) return fnmatch.fnmatch(name, pattern)
@ -213,7 +213,7 @@ class Stat:
... ...
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
return getattr(self._osstatresult, "st_" + name) return getattr(self._osstatresult, 'st_' + name)
def __init__(self, path, osstatresult): def __init__(self, path, osstatresult):
self.path = path self.path = path
@ -222,7 +222,7 @@ class Stat:
@property @property
def owner(self): def owner(self):
if iswin32: if iswin32:
raise NotImplementedError("XXX win32") raise NotImplementedError('XXX win32')
import pwd import pwd
entry = error.checked_call(pwd.getpwuid, self.uid) # type:ignore[attr-defined] entry = error.checked_call(pwd.getpwuid, self.uid) # type:ignore[attr-defined]
@ -232,7 +232,7 @@ class Stat:
def group(self): def group(self):
"""Return group name of file.""" """Return group name of file."""
if iswin32: if iswin32:
raise NotImplementedError("XXX win32") raise NotImplementedError('XXX win32')
import grp import grp
entry = error.checked_call(grp.getgrgid, self.gid) # type:ignore[attr-defined] entry = error.checked_call(grp.getgrgid, self.gid) # type:ignore[attr-defined]
@ -292,14 +292,14 @@ class LocalPath:
path = os.fspath(path) path = os.fspath(path)
except TypeError: except TypeError:
raise ValueError( raise ValueError(
"can only pass None, Path instances " 'can only pass None, Path instances '
"or non-empty strings to LocalPath" 'or non-empty strings to LocalPath',
) )
if expanduser: if expanduser:
path = os.path.expanduser(path) path = os.path.expanduser(path)
self.strpath = abspath(path) self.strpath = abspath(path)
if sys.platform != "win32": if sys.platform != 'win32':
def chown(self, user, group, rec=0): def chown(self, user, group, rec=0):
"""Change ownership to the given user and group. """Change ownership to the given user and group.
@ -334,7 +334,7 @@ class LocalPath:
relsource = self.__class__(value).relto(base) relsource = self.__class__(value).relto(base)
reldest = self.relto(base) reldest = self.relto(base)
n = reldest.count(self.sep) n = reldest.count(self.sep)
target = self.sep.join(("..",) * n + (relsource,)) target = self.sep.join(('..',) * n + (relsource,))
error.checked_call(os.symlink, target, self.strpath) error.checked_call(os.symlink, target, self.strpath)
def __div__(self, other): def __div__(self, other):
@ -345,34 +345,34 @@ class LocalPath:
@property @property
def basename(self): def basename(self):
"""Basename part of path.""" """Basename part of path."""
return self._getbyspec("basename")[0] return self._getbyspec('basename')[0]
@property @property
def dirname(self): def dirname(self):
"""Dirname part of path.""" """Dirname part of path."""
return self._getbyspec("dirname")[0] return self._getbyspec('dirname')[0]
@property @property
def purebasename(self): def purebasename(self):
"""Pure base name of the path.""" """Pure base name of the path."""
return self._getbyspec("purebasename")[0] return self._getbyspec('purebasename')[0]
@property @property
def ext(self): def ext(self):
"""Extension of the path (including the '.').""" """Extension of the path (including the '.')."""
return self._getbyspec("ext")[0] return self._getbyspec('ext')[0]
def read_binary(self): def read_binary(self):
"""Read and return a bytestring from reading the path.""" """Read and return a bytestring from reading the path."""
with self.open("rb") as f: with self.open('rb') as f:
return f.read() return f.read()
def read_text(self, encoding): def read_text(self, encoding):
"""Read and return a Unicode string from reading the path.""" """Read and return a Unicode string from reading the path."""
with self.open("r", encoding=encoding) as f: with self.open('r', encoding=encoding) as f:
return f.read() return f.read()
def read(self, mode="r"): def read(self, mode='r'):
"""Read and return a bytestring from reading the path.""" """Read and return a bytestring from reading the path."""
with self.open(mode) as f: with self.open(mode) as f:
return f.read() return f.read()
@ -380,11 +380,11 @@ class LocalPath:
def readlines(self, cr=1): def readlines(self, cr=1):
"""Read and return a list of lines from the path. if cr is False, the """Read and return a list of lines from the path. if cr is False, the
newline will be removed from the end of each line.""" newline will be removed from the end of each line."""
mode = "r" mode = 'r'
if not cr: if not cr:
content = self.read(mode) content = self.read(mode)
return content.split("\n") return content.split('\n')
else: else:
f = self.open(mode) f = self.open(mode)
try: try:
@ -394,7 +394,7 @@ class LocalPath:
def load(self): def load(self):
"""(deprecated) return object unpickled from self.read()""" """(deprecated) return object unpickled from self.read()"""
f = self.open("rb") f = self.open('rb')
try: try:
import pickle import pickle
@ -405,7 +405,7 @@ class LocalPath:
def move(self, target): def move(self, target):
"""Move this path to target.""" """Move this path to target."""
if target.relto(self): if target.relto(self):
raise error.EINVAL(target, "cannot move path into a subdirectory of itself") raise error.EINVAL(target, 'cannot move path into a subdirectory of itself')
try: try:
self.rename(target) self.rename(target)
except error.EXDEV: # invalid cross-device link except error.EXDEV: # invalid cross-device link
@ -436,19 +436,19 @@ class LocalPath:
to the given 'relpath'. to the given 'relpath'.
""" """
if not isinstance(relpath, (str, LocalPath)): if not isinstance(relpath, (str, LocalPath)):
raise TypeError(f"{relpath!r}: not a string or path object") raise TypeError(f'{relpath!r}: not a string or path object')
strrelpath = str(relpath) strrelpath = str(relpath)
if strrelpath and strrelpath[-1] != self.sep: if strrelpath and strrelpath[-1] != self.sep:
strrelpath += self.sep strrelpath += self.sep
# assert strrelpath[-1] == self.sep # assert strrelpath[-1] == self.sep
# assert strrelpath[-2] != self.sep # assert strrelpath[-2] != self.sep
strself = self.strpath strself = self.strpath
if sys.platform == "win32" or getattr(os, "_name", None) == "nt": if sys.platform == 'win32' or getattr(os, '_name', None) == 'nt':
if os.path.normcase(strself).startswith(os.path.normcase(strrelpath)): if os.path.normcase(strself).startswith(os.path.normcase(strrelpath)):
return strself[len(strrelpath) :] return strself[len(strrelpath):]
elif strself.startswith(strrelpath): elif strself.startswith(strrelpath):
return strself[len(strrelpath) :] return strself[len(strrelpath):]
return "" return ''
def ensure_dir(self, *args): def ensure_dir(self, *args):
"""Ensure the path joined with args is a directory.""" """Ensure the path joined with args is a directory."""
@ -542,10 +542,10 @@ class LocalPath:
def _sortlist(self, res, sort): def _sortlist(self, res, sort):
if sort: if sort:
if hasattr(sort, "__call__"): if hasattr(sort, '__call__'):
warnings.warn( warnings.warn(
DeprecationWarning( DeprecationWarning(
"listdir(sort=callable) is deprecated and breaks on python3" 'listdir(sort=callable) is deprecated and breaks on python3',
), ),
stacklevel=3, stacklevel=3,
) )
@ -592,7 +592,7 @@ class LocalPath:
other = abspath(other) other = abspath(other)
if self == other: if self == other:
return True return True
if not hasattr(os.path, "samefile"): if not hasattr(os.path, 'samefile'):
return False return False
return error.checked_call(os.path.samefile, self.strpath, other) return error.checked_call(os.path.samefile, self.strpath, other)
@ -609,7 +609,7 @@ class LocalPath:
import shutil import shutil
error.checked_call( error.checked_call(
shutil.rmtree, self.strpath, ignore_errors=ignore_errors shutil.rmtree, self.strpath, ignore_errors=ignore_errors,
) )
else: else:
error.checked_call(os.rmdir, self.strpath) error.checked_call(os.rmdir, self.strpath)
@ -618,19 +618,19 @@ class LocalPath:
self.chmod(0o700) self.chmod(0o700)
error.checked_call(os.remove, self.strpath) error.checked_call(os.remove, self.strpath)
def computehash(self, hashtype="md5", chunksize=524288): def computehash(self, hashtype='md5', chunksize=524288):
"""Return hexdigest of hashvalue for this file.""" """Return hexdigest of hashvalue for this file."""
try: try:
try: try:
import hashlib as mod import hashlib as mod
except ImportError: except ImportError:
if hashtype == "sha1": if hashtype == 'sha1':
hashtype = "sha" hashtype = 'sha'
mod = __import__(hashtype) mod = __import__(hashtype)
hash = getattr(mod, hashtype)() hash = getattr(mod, hashtype)()
except (AttributeError, ImportError): except (AttributeError, ImportError):
raise ValueError(f"Don't know how to compute {hashtype!r} hash") raise ValueError(f"Don't know how to compute {hashtype!r} hash")
f = self.open("rb") f = self.open('rb')
try: try:
while 1: while 1:
buf = f.read(chunksize) buf = f.read(chunksize)
@ -656,28 +656,28 @@ class LocalPath:
obj.strpath = self.strpath obj.strpath = self.strpath
return obj return obj
drive, dirname, basename, purebasename, ext = self._getbyspec( drive, dirname, basename, purebasename, ext = self._getbyspec(
"drive,dirname,basename,purebasename,ext" 'drive,dirname,basename,purebasename,ext',
) )
if "basename" in kw: if 'basename' in kw:
if "purebasename" in kw or "ext" in kw: if 'purebasename' in kw or 'ext' in kw:
raise ValueError("invalid specification %r" % kw) raise ValueError('invalid specification %r' % kw)
else: else:
pb = kw.setdefault("purebasename", purebasename) pb = kw.setdefault('purebasename', purebasename)
try: try:
ext = kw["ext"] ext = kw['ext']
except KeyError: except KeyError:
pass pass
else: else:
if ext and not ext.startswith("."): if ext and not ext.startswith('.'):
ext = "." + ext ext = '.' + ext
kw["basename"] = pb + ext kw['basename'] = pb + ext
if "dirname" in kw and not kw["dirname"]: if 'dirname' in kw and not kw['dirname']:
kw["dirname"] = drive kw['dirname'] = drive
else: else:
kw.setdefault("dirname", dirname) kw.setdefault('dirname', dirname)
kw.setdefault("sep", self.sep) kw.setdefault('sep', self.sep)
obj.strpath = normpath("{dirname}{sep}{basename}".format(**kw)) obj.strpath = normpath('{dirname}{sep}{basename}'.format(**kw))
return obj return obj
def _getbyspec(self, spec: str) -> list[str]: def _getbyspec(self, spec: str) -> list[str]:
@ -685,28 +685,28 @@ class LocalPath:
res = [] res = []
parts = self.strpath.split(self.sep) parts = self.strpath.split(self.sep)
args = filter(None, spec.split(",")) args = filter(None, spec.split(','))
for name in args: for name in args:
if name == "drive": if name == 'drive':
res.append(parts[0]) res.append(parts[0])
elif name == "dirname": elif name == 'dirname':
res.append(self.sep.join(parts[:-1])) res.append(self.sep.join(parts[:-1]))
else: else:
basename = parts[-1] basename = parts[-1]
if name == "basename": if name == 'basename':
res.append(basename) res.append(basename)
else: else:
i = basename.rfind(".") i = basename.rfind('.')
if i == -1: if i == -1:
purebasename, ext = basename, "" purebasename, ext = basename, ''
else: else:
purebasename, ext = basename[:i], basename[i:] purebasename, ext = basename[:i], basename[i:]
if name == "purebasename": if name == 'purebasename':
res.append(purebasename) res.append(purebasename)
elif name == "ext": elif name == 'ext':
res.append(ext) res.append(ext)
else: else:
raise ValueError("invalid part specification %r" % name) raise ValueError('invalid part specification %r' % name)
return res return res
def dirpath(self, *args, **kwargs): def dirpath(self, *args, **kwargs):
@ -717,7 +717,7 @@ class LocalPath:
if args: if args:
path = path.join(*args) path = path.join(*args)
return path return path
return self.new(basename="").join(*args, **kwargs) return self.new(basename='').join(*args, **kwargs)
def join(self, *args: os.PathLike[str], abs: bool = False) -> LocalPath: def join(self, *args: os.PathLike[str], abs: bool = False) -> LocalPath:
"""Return a new path by appending all 'args' as path """Return a new path by appending all 'args' as path
@ -736,20 +736,20 @@ class LocalPath:
break break
newargs.insert(0, arg) newargs.insert(0, arg)
# special case for when we have e.g. strpath == "/" # special case for when we have e.g. strpath == "/"
actual_sep = "" if strpath.endswith(sep) else sep actual_sep = '' if strpath.endswith(sep) else sep
for arg in strargs: for arg in strargs:
arg = arg.strip(sep) arg = arg.strip(sep)
if iswin32: if iswin32:
# allow unix style paths even on windows. # allow unix style paths even on windows.
arg = arg.strip("/") arg = arg.strip('/')
arg = arg.replace("/", sep) arg = arg.replace('/', sep)
strpath = strpath + actual_sep + arg strpath = strpath + actual_sep + arg
actual_sep = sep actual_sep = sep
obj = object.__new__(self.__class__) obj = object.__new__(self.__class__)
obj.strpath = normpath(strpath) obj.strpath = normpath(strpath)
return obj return obj
def open(self, mode="r", ensure=False, encoding=None): def open(self, mode='r', ensure=False, encoding=None):
"""Return an opened file with the given mode. """Return an opened file with the given mode.
If ensure is True, create parent directories if needed. If ensure is True, create parent directories if needed.
@ -797,15 +797,15 @@ class LocalPath:
if not kw: if not kw:
return exists(self.strpath) return exists(self.strpath)
if len(kw) == 1: if len(kw) == 1:
if "dir" in kw: if 'dir' in kw:
return not kw["dir"] ^ isdir(self.strpath) return not kw['dir'] ^ isdir(self.strpath)
if "file" in kw: if 'file' in kw:
return not kw["file"] ^ isfile(self.strpath) return not kw['file'] ^ isfile(self.strpath)
if not kw: if not kw:
kw = {"exists": 1} kw = {'exists': 1}
return Checkers(self)._evaluate(kw) return Checkers(self)._evaluate(kw)
_patternchars = set("*?[" + os.sep) _patternchars = set('*?[' + os.sep)
def listdir(self, fil=None, sort=None): def listdir(self, fil=None, sort=None):
"""List directory contents, possibly filter by the given fil func """List directory contents, possibly filter by the given fil func
@ -882,7 +882,7 @@ class LocalPath:
def dump(self, obj, bin=1): def dump(self, obj, bin=1):
"""Pickle object into path location""" """Pickle object into path location"""
f = self.open("wb") f = self.open('wb')
import pickle import pickle
try: try:
@ -902,7 +902,7 @@ class LocalPath:
""" """
if ensure: if ensure:
self.dirpath().ensure(dir=1) self.dirpath().ensure(dir=1)
with self.open("wb") as f: with self.open('wb') as f:
f.write(data) f.write(data)
def write_text(self, data, encoding, ensure=False): def write_text(self, data, encoding, ensure=False):
@ -911,18 +911,18 @@ class LocalPath:
""" """
if ensure: if ensure:
self.dirpath().ensure(dir=1) self.dirpath().ensure(dir=1)
with self.open("w", encoding=encoding) as f: with self.open('w', encoding=encoding) as f:
f.write(data) f.write(data)
def write(self, data, mode="w", ensure=False): def write(self, data, mode='w', ensure=False):
"""Write data into path. If ensure is True create """Write data into path. If ensure is True create
missing parent directories. missing parent directories.
""" """
if ensure: if ensure:
self.dirpath().ensure(dir=1) self.dirpath().ensure(dir=1)
if "b" in mode: if 'b' in mode:
if not isinstance(data, bytes): if not isinstance(data, bytes):
raise ValueError("can only process bytes") raise ValueError('can only process bytes')
else: else:
if not isinstance(data, str): if not isinstance(data, str):
if not isinstance(data, bytes): if not isinstance(data, bytes):
@ -957,12 +957,12 @@ class LocalPath:
then the path is forced to be a directory path. then the path is forced to be a directory path.
""" """
p = self.join(*args) p = self.join(*args)
if kwargs.get("dir", 0): if kwargs.get('dir', 0):
return p._ensuredirs() return p._ensuredirs()
else: else:
p.dirpath()._ensuredirs() p.dirpath()._ensuredirs()
if not p.check(file=1): if not p.check(file=1):
p.open("wb").close() p.open('wb').close()
return p return p
@overload @overload
@ -1033,7 +1033,7 @@ class LocalPath:
return self.stat().atime return self.stat().atime
def __repr__(self): def __repr__(self):
return "local(%r)" % self.strpath return 'local(%r)' % self.strpath
def __str__(self): def __str__(self):
"""Return string representation of the Path.""" """Return string representation of the Path."""
@ -1045,7 +1045,7 @@ class LocalPath:
if rec is True perform recursively. if rec is True perform recursively.
""" """
if not isinstance(mode, int): if not isinstance(mode, int):
raise TypeError(f"mode {mode!r} must be an integer") raise TypeError(f'mode {mode!r} must be an integer')
if rec: if rec:
for x in self.visit(rec=rec): for x in self.visit(rec=rec):
error.checked_call(os.chmod, str(x), mode) error.checked_call(os.chmod, str(x), mode)
@ -1059,7 +1059,7 @@ class LocalPath:
pkgpath = None pkgpath = None
for parent in self.parts(reverse=True): for parent in self.parts(reverse=True):
if parent.isdir(): if parent.isdir():
if not parent.join("__init__.py").exists(): if not parent.join('__init__.py').exists():
break break
if not isimportable(parent.basename): if not isimportable(parent.basename):
break break
@ -1069,7 +1069,7 @@ class LocalPath:
def _ensuresyspath(self, ensuremode, path): def _ensuresyspath(self, ensuremode, path):
if ensuremode: if ensuremode:
s = str(path) s = str(path)
if ensuremode == "append": if ensuremode == 'append':
if s not in sys.path: if s not in sys.path:
sys.path.append(s) sys.path.append(s)
else: else:
@ -1100,7 +1100,7 @@ class LocalPath:
if not self.check(): if not self.check():
raise error.ENOENT(self) raise error.ENOENT(self)
if ensuresyspath == "importlib": if ensuresyspath == 'importlib':
if modname is None: if modname is None:
modname = self.purebasename modname = self.purebasename
spec = importlib.util.spec_from_file_location(modname, str(self)) spec = importlib.util.spec_from_file_location(modname, str(self))
@ -1115,10 +1115,10 @@ class LocalPath:
pkgpath = self.pypkgpath() pkgpath = self.pypkgpath()
if pkgpath is not None: if pkgpath is not None:
pkgroot = pkgpath.dirpath() pkgroot = pkgpath.dirpath()
names = self.new(ext="").relto(pkgroot).split(self.sep) names = self.new(ext='').relto(pkgroot).split(self.sep)
if names[-1] == "__init__": if names[-1] == '__init__':
names.pop() names.pop()
modname = ".".join(names) modname = '.'.join(names)
else: else:
pkgroot = self.dirpath() pkgroot = self.dirpath()
modname = self.purebasename modname = self.purebasename
@ -1126,25 +1126,25 @@ class LocalPath:
self._ensuresyspath(ensuresyspath, pkgroot) self._ensuresyspath(ensuresyspath, pkgroot)
__import__(modname) __import__(modname)
mod = sys.modules[modname] mod = sys.modules[modname]
if self.basename == "__init__.py": if self.basename == '__init__.py':
return mod # we don't check anything as we might return mod # we don't check anything as we might
# be in a namespace package ... too icky to check # be in a namespace package ... too icky to check
modfile = mod.__file__ modfile = mod.__file__
assert modfile is not None assert modfile is not None
if modfile[-4:] in (".pyc", ".pyo"): if modfile[-4:] in ('.pyc', '.pyo'):
modfile = modfile[:-1] modfile = modfile[:-1]
elif modfile.endswith("$py.class"): elif modfile.endswith('$py.class'):
modfile = modfile[:-9] + ".py" modfile = modfile[:-9] + '.py'
if modfile.endswith(os.sep + "__init__.py"): if modfile.endswith(os.sep + '__init__.py'):
if self.basename != "__init__.py": if self.basename != '__init__.py':
modfile = modfile[:-12] modfile = modfile[:-12]
try: try:
issame = self.samefile(modfile) issame = self.samefile(modfile)
except error.ENOENT: except error.ENOENT:
issame = False issame = False
if not issame: if not issame:
ignore = os.getenv("PY_IGNORE_IMPORTMISMATCH") ignore = os.getenv('PY_IGNORE_IMPORTMISMATCH')
if ignore != "1": if ignore != '1':
raise self.ImportMismatchError(modname, modfile, self) raise self.ImportMismatchError(modname, modfile, self)
return mod return mod
else: else:
@ -1158,7 +1158,7 @@ class LocalPath:
mod.__file__ = str(self) mod.__file__ = str(self)
sys.modules[modname] = mod sys.modules[modname] = mod
try: try:
with open(str(self), "rb") as f: with open(str(self), 'rb') as f:
exec(f.read(), mod.__dict__) exec(f.read(), mod.__dict__)
except BaseException: except BaseException:
del sys.modules[modname] del sys.modules[modname]
@ -1173,8 +1173,8 @@ class LocalPath:
from subprocess import PIPE from subprocess import PIPE
from subprocess import Popen from subprocess import Popen
popen_opts.pop("stdout", None) popen_opts.pop('stdout', None)
popen_opts.pop("stderr", None) popen_opts.pop('stderr', None)
proc = Popen( proc = Popen(
[str(self)] + [str(arg) for arg in argv], [str(self)] + [str(arg) for arg in argv],
**popen_opts, **popen_opts,
@ -1214,23 +1214,23 @@ class LocalPath:
else: else:
if paths is None: if paths is None:
if iswin32: if iswin32:
paths = os.environ["Path"].split(";") paths = os.environ['Path'].split(';')
if "" not in paths and "." not in paths: if '' not in paths and '.' not in paths:
paths.append(".") paths.append('.')
try: try:
systemroot = os.environ["SYSTEMROOT"] systemroot = os.environ['SYSTEMROOT']
except KeyError: except KeyError:
pass pass
else: else:
paths = [ paths = [
path.replace("%SystemRoot%", systemroot) for path in paths path.replace('%SystemRoot%', systemroot) for path in paths
] ]
else: else:
paths = os.environ["PATH"].split(":") paths = os.environ['PATH'].split(':')
tryadd = [] tryadd = []
if iswin32: if iswin32:
tryadd += os.environ["PATHEXT"].split(os.pathsep) tryadd += os.environ['PATHEXT'].split(os.pathsep)
tryadd.append("") tryadd.append('')
for x in paths: for x in paths:
for addext in tryadd: for addext in tryadd:
@ -1248,10 +1248,10 @@ class LocalPath:
@classmethod @classmethod
def _gethomedir(cls): def _gethomedir(cls):
try: try:
x = os.environ["HOME"] x = os.environ['HOME']
except KeyError: except KeyError:
try: try:
x = os.environ["HOMEDRIVE"] + os.environ["HOMEPATH"] x = os.environ['HOMEDRIVE'] + os.environ['HOMEPATH']
except KeyError: except KeyError:
return None return None
return cls(x) return cls(x)
@ -1288,7 +1288,7 @@ class LocalPath:
@classmethod @classmethod
def make_numbered_dir( def make_numbered_dir(
cls, prefix="session-", rootdir=None, keep=3, lock_timeout=172800 cls, prefix='session-', rootdir=None, keep=3, lock_timeout=172800,
): # two days ): # two days
"""Return unique directory with a number greater than the current """Return unique directory with a number greater than the current
maximum one. The number is assumed to start directly after prefix. maximum one. The number is assumed to start directly after prefix.
@ -1306,21 +1306,21 @@ class LocalPath:
nbasename = path.basename.lower() nbasename = path.basename.lower()
if nbasename.startswith(nprefix): if nbasename.startswith(nprefix):
try: try:
return int(nbasename[len(nprefix) :]) return int(nbasename[len(nprefix):])
except ValueError: except ValueError:
pass pass
def create_lockfile(path): def create_lockfile(path):
"""Exclusively create lockfile. Throws when failed""" """Exclusively create lockfile. Throws when failed"""
mypid = os.getpid() mypid = os.getpid()
lockfile = path.join(".lock") lockfile = path.join('.lock')
if hasattr(lockfile, "mksymlinkto"): if hasattr(lockfile, 'mksymlinkto'):
lockfile.mksymlinkto(str(mypid)) lockfile.mksymlinkto(str(mypid))
else: else:
fd = error.checked_call( fd = error.checked_call(
os.open, str(lockfile), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644 os.open, str(lockfile), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644,
) )
with os.fdopen(fd, "w") as f: with os.fdopen(fd, 'w') as f:
f.write(str(mypid)) f.write(str(mypid))
return lockfile return lockfile
@ -1380,7 +1380,7 @@ class LocalPath:
except error.Error: except error.Error:
pass pass
garbage_prefix = prefix + "garbage-" garbage_prefix = prefix + 'garbage-'
def is_garbage(path): def is_garbage(path):
"""Check if path denotes directory scheduled for removal""" """Check if path denotes directory scheduled for removal"""
@ -1428,15 +1428,15 @@ class LocalPath:
# make link... # make link...
try: try:
username = os.environ["USER"] # linux, et al username = os.environ['USER'] # linux, et al
except KeyError: except KeyError:
try: try:
username = os.environ["USERNAME"] # windows username = os.environ['USERNAME'] # windows
except KeyError: except KeyError:
username = "current" username = 'current'
src = str(udir) src = str(udir)
dest = src[: src.rfind("-")] + "-" + username dest = src[: src.rfind('-')] + '-' + username
try: try:
os.unlink(dest) os.unlink(dest)
except OSError: except OSError:
@ -1466,9 +1466,9 @@ def copystat(src, dest):
def copychunked(src, dest): def copychunked(src, dest):
chunksize = 524288 # half a meg of bytes chunksize = 524288 # half a meg of bytes
fsrc = src.open("rb") fsrc = src.open('rb')
try: try:
fdest = dest.open("wb") fdest = dest.open('wb')
try: try:
while 1: while 1:
buf = fsrc.read(chunksize) buf = fsrc.read(chunksize)
@ -1482,8 +1482,8 @@ def copychunked(src, dest):
def isimportable(name): def isimportable(name):
if name and (name[0].isalpha() or name[0] == "_"): if name and (name[0].isalpha() or name[0] == '_'):
name = name.replace("_", "") name = name.replace('_', '')
return not name or name.isalnum() return not name or name.isalnum()

View file

@ -1,5 +1,6 @@
# file generated by setuptools_scm # file generated by setuptools_scm
# don't change, don't track in version control # don't change, don't track in version control
from __future__ import annotations
TYPE_CHECKING = False TYPE_CHECKING = False
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Tuple, Union from typing import Tuple, Union

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Support for presenting detailed information in failing assertions.""" """Support for presenting detailed information in failing assertions."""
from __future__ import annotations
import sys import sys
from typing import Any from typing import Any
from typing import Generator from typing import Generator
@ -22,34 +24,34 @@ if TYPE_CHECKING:
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig") group = parser.getgroup('debugconfig')
group.addoption( group.addoption(
"--assert", '--assert',
action="store", action='store',
dest="assertmode", dest='assertmode',
choices=("rewrite", "plain"), choices=('rewrite', 'plain'),
default="rewrite", default='rewrite',
metavar="MODE", metavar='MODE',
help=( help=(
"Control assertion debugging tools.\n" 'Control assertion debugging tools.\n'
"'plain' performs no assertion debugging.\n" "'plain' performs no assertion debugging.\n"
"'rewrite' (the default) rewrites assert statements in test modules" "'rewrite' (the default) rewrites assert statements in test modules"
" on import to provide assert expression information." ' on import to provide assert expression information.'
), ),
) )
parser.addini( parser.addini(
"enable_assertion_pass_hook", 'enable_assertion_pass_hook',
type="bool", type='bool',
default=False, default=False,
help="Enables the pytest_assertion_pass hook. " help='Enables the pytest_assertion_pass hook. '
"Make sure to delete any previously generated pyc cache files.", 'Make sure to delete any previously generated pyc cache files.',
) )
Config._add_verbosity_ini( Config._add_verbosity_ini(
parser, parser,
Config.VERBOSITY_ASSERTIONS, Config.VERBOSITY_ASSERTIONS,
help=( help=(
"Specify a verbosity level for assertions, overriding the main level. " 'Specify a verbosity level for assertions, overriding the main level. '
"Higher levels will provide more detailed explanation when an assertion fails." 'Higher levels will provide more detailed explanation when an assertion fails.'
), ),
) )
@ -67,7 +69,7 @@ def register_assert_rewrite(*names: str) -> None:
""" """
for name in names: for name in names:
if not isinstance(name, str): if not isinstance(name, str):
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable] msg = 'expected module names as *args, got {0} instead' # type: ignore[unreachable]
raise TypeError(msg.format(repr(names))) raise TypeError(msg.format(repr(names)))
for hook in sys.meta_path: for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook): if isinstance(hook, rewrite.AssertionRewritingHook):
@ -92,16 +94,16 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None: def __init__(self, config: Config, mode) -> None:
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get('assertion')
self.hook: Optional[rewrite.AssertionRewritingHook] = None self.hook: rewrite.AssertionRewritingHook | None = None
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails.""" """Try to install the rewrite hook, raise SystemError if it fails."""
config.stash[assertstate_key] = AssertionState(config, "rewrite") config.stash[assertstate_key] = AssertionState(config, 'rewrite')
config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook) sys.meta_path.insert(0, hook)
config.stash[assertstate_key].trace("installed rewrite import hook") config.stash[assertstate_key].trace('installed rewrite import hook')
def undo() -> None: def undo() -> None:
hook = config.stash[assertstate_key].hook hook = config.stash[assertstate_key].hook
@ -112,7 +114,7 @@ def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
return hook return hook
def pytest_collection(session: "Session") -> None: def pytest_collection(session: Session) -> None:
# This hook is only called when test modules are collected # This hook is only called when test modules are collected
# so for example not in the managing process of pytest-xdist # so for example not in the managing process of pytest-xdist
# (which does not collect test modules). # (which does not collect test modules).
@ -132,7 +134,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
""" """
ihook = item.ihook ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> Optional[str]: def callbinrepr(op, left: object, right: object) -> str | None:
"""Call the pytest_assertrepr_compare hook and prepare the result. """Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the This uses the first result from the hook and then ensures the
@ -148,15 +150,15 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
pretty printing. pretty printing.
""" """
hook_result = ihook.pytest_assertrepr_compare( hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right config=item.config, op=op, left=left, right=right,
) )
for new_expl in hook_result: for new_expl in hook_result:
if new_expl: if new_expl:
new_expl = truncate.truncate_if_required(new_expl, item) new_expl = truncate.truncate_if_required(new_expl, item)
new_expl = [line.replace("\n", "\\n") for line in new_expl] new_expl = [line.replace('\n', '\\n') for line in new_expl]
res = "\n~".join(new_expl) res = '\n~'.join(new_expl)
if item.config.getvalue("assertmode") == "rewrite": if item.config.getvalue('assertmode') == 'rewrite':
res = res.replace("%", "%%") res = res.replace('%', '%%')
return res return res
return None return None
@ -178,7 +180,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
util._config = None util._config = None
def pytest_sessionfinish(session: "Session") -> None: def pytest_sessionfinish(session: Session) -> None:
assertstate = session.config.stash.get(assertstate_key, None) assertstate = session.config.stash.get(assertstate_key, None)
if assertstate: if assertstate:
if assertstate.hook is not None: if assertstate.hook is not None:
@ -186,6 +188,6 @@ def pytest_sessionfinish(session: "Session") -> None:
def pytest_assertrepr_compare( def pytest_assertrepr_compare(
config: Config, op: str, left: Any, right: Any config: Config, op: str, left: Any, right: Any,
) -> Optional[List[str]]: ) -> list[str] | None:
return util.assertrepr_compare(config=config, op=op, left=left, right=right) return util.assertrepr_compare(config=config, op=op, left=left, right=right)

View file

@ -3,6 +3,7 @@
Current default behaviour is to truncate assertion explanations at Current default behaviour is to truncate assertion explanations at
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI. terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
""" """
from __future__ import annotations
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -18,8 +19,8 @@ USAGE_MSG = "use '-vv' to show"
def truncate_if_required( def truncate_if_required(
explanation: List[str], item: Item, max_length: Optional[int] = None explanation: list[str], item: Item, max_length: int | None = None,
) -> List[str]: ) -> list[str]:
"""Truncate this assertion explanation if the given test item is eligible.""" """Truncate this assertion explanation if the given test item is eligible."""
if _should_truncate_item(item): if _should_truncate_item(item):
return _truncate_explanation(explanation) return _truncate_explanation(explanation)
@ -33,10 +34,10 @@ def _should_truncate_item(item: Item) -> bool:
def _truncate_explanation( def _truncate_explanation(
input_lines: List[str], input_lines: list[str],
max_lines: Optional[int] = None, max_lines: int | None = None,
max_chars: Optional[int] = None, max_chars: int | None = None,
) -> List[str]: ) -> list[str]:
"""Truncate given list of strings that makes up the assertion explanation. """Truncate given list of strings that makes up the assertion explanation.
Truncates to either 8 lines, or 640 characters - whichever the input reaches Truncates to either 8 lines, or 640 characters - whichever the input reaches
@ -49,7 +50,7 @@ def _truncate_explanation(
max_chars = DEFAULT_MAX_CHARS max_chars = DEFAULT_MAX_CHARS
# Check if truncation required # Check if truncation required
input_char_count = len("".join(input_lines)) input_char_count = len(''.join(input_lines))
# The length of the truncation explanation depends on the number of lines # The length of the truncation explanation depends on the number of lines
# removed but is at least 68 characters: # removed but is at least 68 characters:
# The real value is # The real value is
@ -67,17 +68,17 @@ def _truncate_explanation(
# The truncation explanation add two lines to the output # The truncation explanation add two lines to the output
tolerable_max_lines = max_lines + 2 tolerable_max_lines = max_lines + 2
if ( if (
len(input_lines) <= tolerable_max_lines len(input_lines) <= tolerable_max_lines and
and input_char_count <= tolerable_max_chars input_char_count <= tolerable_max_chars
): ):
return input_lines return input_lines
# Truncate first to max_lines, and then truncate to max_chars if necessary # Truncate first to max_lines, and then truncate to max_chars if necessary
truncated_explanation = input_lines[:max_lines] truncated_explanation = input_lines[:max_lines]
truncated_char = True truncated_char = True
# We reevaluate the need to truncate chars following removal of some lines # We reevaluate the need to truncate chars following removal of some lines
if len("".join(truncated_explanation)) > tolerable_max_chars: if len(''.join(truncated_explanation)) > tolerable_max_chars:
truncated_explanation = _truncate_by_char_count( truncated_explanation = _truncate_by_char_count(
truncated_explanation, max_chars truncated_explanation, max_chars,
) )
else: else:
truncated_char = False truncated_char = False
@ -85,22 +86,22 @@ def _truncate_explanation(
truncated_line_count = len(input_lines) - len(truncated_explanation) truncated_line_count = len(input_lines) - len(truncated_explanation)
if truncated_explanation[-1]: if truncated_explanation[-1]:
# Add ellipsis and take into account part-truncated final line # Add ellipsis and take into account part-truncated final line
truncated_explanation[-1] = truncated_explanation[-1] + "..." truncated_explanation[-1] = truncated_explanation[-1] + '...'
if truncated_char: if truncated_char:
# It's possible that we did not remove any char from this line # It's possible that we did not remove any char from this line
truncated_line_count += 1 truncated_line_count += 1
else: else:
# Add proper ellipsis when we were able to fit a full line exactly # Add proper ellipsis when we were able to fit a full line exactly
truncated_explanation[-1] = "..." truncated_explanation[-1] = '...'
return [ return [
*truncated_explanation, *truncated_explanation,
"", '',
f"...Full output truncated ({truncated_line_count} line" f'...Full output truncated ({truncated_line_count} line'
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}", f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
] ]
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]: def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
# Find point at which input length exceeds total allowed length # Find point at which input length exceeds total allowed length
iterated_char_count = 0 iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines): for iterated_index, input_line in enumerate(input_lines):

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Utilities for assertion debugging.""" """Utilities for assertion debugging."""
from __future__ import annotations
import collections.abc import collections.abc
import os import os
import pprint import pprint
@ -15,8 +17,8 @@ from typing import Protocol
from typing import Sequence from typing import Sequence
from unicodedata import normalize from unicodedata import normalize
from _pytest import outcomes
import _pytest._code import _pytest._code
from _pytest import outcomes
from _pytest._io.pprint import PrettyPrinter from _pytest._io.pprint import PrettyPrinter
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited from _pytest._io.saferepr import saferepr_unlimited
@ -27,18 +29,18 @@ from _pytest.config import Config
# interpretation code and assertion rewriter to detect this plugin was # interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the # loaded and in turn call the hooks defined here as part of the
# DebugInterpreter. # DebugInterpreter.
_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None _reprcompare: Callable[[str, object, object], str | None] | None = None
# Works similarly as _reprcompare attribute. Is populated with the hook call # Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called. # when pytest_runtest_setup is called.
_assertion_pass: Optional[Callable[[int, str, str], None]] = None _assertion_pass: Callable[[int, str, str], None] | None = None
# Config object which is assigned during pytest_runtest_protocol. # Config object which is assigned during pytest_runtest_protocol.
_config: Optional[Config] = None _config: Config | None = None
class _HighlightFunc(Protocol): class _HighlightFunc(Protocol):
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: def __call__(self, source: str, lexer: Literal['diff', 'python'] = 'python') -> str:
"""Apply highlighting to the given source.""" """Apply highlighting to the given source."""
@ -54,27 +56,27 @@ def format_explanation(explanation: str) -> str:
""" """
lines = _split_explanation(explanation) lines = _split_explanation(explanation)
result = _format_lines(lines) result = _format_lines(lines)
return "\n".join(result) return '\n'.join(result)
def _split_explanation(explanation: str) -> List[str]: def _split_explanation(explanation: str) -> list[str]:
r"""Return a list of individual lines in the explanation. r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'. This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the Any other newlines will be escaped and appear in the line as the
literal '\n' characters. literal '\n' characters.
""" """
raw_lines = (explanation or "").split("\n") raw_lines = (explanation or '').split('\n')
lines = [raw_lines[0]] lines = [raw_lines[0]]
for values in raw_lines[1:]: for values in raw_lines[1:]:
if values and values[0] in ["{", "}", "~", ">"]: if values and values[0] in ['{', '}', '~', '>']:
lines.append(values) lines.append(values)
else: else:
lines[-1] += "\\n" + values lines[-1] += '\\n' + values
return lines return lines
def _format_lines(lines: Sequence[str]) -> List[str]: def _format_lines(lines: Sequence[str]) -> list[str]:
"""Format the individual lines. """Format the individual lines.
This will replace the '{', '}' and '~' characters of our mini formatting This will replace the '{', '}' and '~' characters of our mini formatting
@ -87,24 +89,24 @@ def _format_lines(lines: Sequence[str]) -> List[str]:
stack = [0] stack = [0]
stackcnt = [0] stackcnt = [0]
for line in lines[1:]: for line in lines[1:]:
if line.startswith("{"): if line.startswith('{'):
if stackcnt[-1]: if stackcnt[-1]:
s = "and " s = 'and '
else: else:
s = "where " s = 'where '
stack.append(len(result)) stack.append(len(result))
stackcnt[-1] += 1 stackcnt[-1] += 1
stackcnt.append(0) stackcnt.append(0)
result.append(" +" + " " * (len(stack) - 1) + s + line[1:]) result.append(' +' + ' ' * (len(stack) - 1) + s + line[1:])
elif line.startswith("}"): elif line.startswith('}'):
stack.pop() stack.pop()
stackcnt.pop() stackcnt.pop()
result[stack[-1]] += line[1:] result[stack[-1]] += line[1:]
else: else:
assert line[0] in ["~", ">"] assert line[0] in ['~', '>']
stack[-1] += 1 stack[-1] += 1
indent = len(stack) if line.startswith("~") else len(stack) - 1 indent = len(stack) if line.startswith('~') else len(stack) - 1
result.append(" " * indent + line[1:]) result.append(' ' * indent + line[1:])
assert len(stack) == 1 assert len(stack) == 1
return result return result
@ -126,15 +128,15 @@ def isset(x: Any) -> bool:
def isnamedtuple(obj: Any) -> bool: def isnamedtuple(obj: Any) -> bool:
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None return isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None
def isdatacls(obj: Any) -> bool: def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None return getattr(obj, '__dataclass_fields__', None) is not None
def isattrs(obj: Any) -> bool: def isattrs(obj: Any) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None return getattr(obj, '__attrs_attrs__', None) is not None
def isiterable(obj: Any) -> bool: def isiterable(obj: Any) -> bool:
@ -156,28 +158,28 @@ def has_default_eq(
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated" for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
""" """
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68 # inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"): if hasattr(obj.__eq__, '__code__') and hasattr(obj.__eq__.__code__, 'co_filename'):
code_filename = obj.__eq__.__code__.co_filename code_filename = obj.__eq__.__code__.co_filename
if isattrs(obj): if isattrs(obj):
return "attrs generated eq" in code_filename return 'attrs generated eq' in code_filename
return code_filename == "<string>" # data class return code_filename == '<string>' # data class
return True return True
def assertrepr_compare( def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False config, op: str, left: Any, right: Any, use_ascii: bool = False,
) -> Optional[List[str]]: ) -> list[str] | None:
"""Return specialised explanations for some operators/operands.""" """Return specialised explanations for some operators/operands."""
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS) verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier. # Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
# See issue #3246. # See issue #3246.
use_ascii = ( use_ascii = (
isinstance(left, str) isinstance(left, str) and
and isinstance(right, str) isinstance(right, str) and
and normalize("NFD", left) == normalize("NFD", right) normalize('NFD', left) == normalize('NFD', right)
) )
if verbose > 1: if verbose > 1:
@ -193,29 +195,29 @@ def assertrepr_compare(
left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii) left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii)
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii) right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
summary = f"{left_repr} {op} {right_repr}" summary = f'{left_repr} {op} {right_repr}'
highlighter = config.get_terminal_writer()._highlight highlighter = config.get_terminal_writer()._highlight
explanation = None explanation = None
try: try:
if op == "==": if op == '==':
explanation = _compare_eq_any(left, right, highlighter, verbose) explanation = _compare_eq_any(left, right, highlighter, verbose)
elif op == "not in": elif op == 'not in':
if istext(left) and istext(right): if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose) explanation = _notin_text(left, right, verbose)
elif op == "!=": elif op == '!=':
if isset(left) and isset(right): if isset(left) and isset(right):
explanation = ["Both sets are equal"] explanation = ['Both sets are equal']
elif op == ">=": elif op == '>=':
if isset(left) and isset(right): if isset(left) and isset(right):
explanation = _compare_gte_set(left, right, highlighter, verbose) explanation = _compare_gte_set(left, right, highlighter, verbose)
elif op == "<=": elif op == '<=':
if isset(left) and isset(right): if isset(left) and isset(right):
explanation = _compare_lte_set(left, right, highlighter, verbose) explanation = _compare_lte_set(left, right, highlighter, verbose)
elif op == ">": elif op == '>':
if isset(left) and isset(right): if isset(left) and isset(right):
explanation = _compare_gt_set(left, right, highlighter, verbose) explanation = _compare_gt_set(left, right, highlighter, verbose)
elif op == "<": elif op == '<':
if isset(left) and isset(right): if isset(left) and isset(right):
explanation = _compare_lt_set(left, right, highlighter, verbose) explanation = _compare_lt_set(left, right, highlighter, verbose)
@ -223,23 +225,23 @@ def assertrepr_compare(
raise raise
except Exception: except Exception:
explanation = [ explanation = [
"(pytest_assertion plugin: representation of details failed: {}.".format( '(pytest_assertion plugin: representation of details failed: {}.'.format(
_pytest._code.ExceptionInfo.from_current()._getreprcrash() _pytest._code.ExceptionInfo.from_current()._getreprcrash(),
), ),
" Probably an object has a faulty __repr__.)", ' Probably an object has a faulty __repr__.)',
] ]
if not explanation: if not explanation:
return None return None
if explanation[0] != "": if explanation[0] != '':
explanation = ["", *explanation] explanation = ['', *explanation]
return [summary, *explanation] return [summary, *explanation]
def _compare_eq_any( def _compare_eq_any(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0 left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation = [] explanation = []
if istext(left) and istext(right): if istext(left) and istext(right):
explanation = _diff_text(left, right, verbose) explanation = _diff_text(left, right, verbose)
@ -274,7 +276,7 @@ def _compare_eq_any(
return explanation return explanation
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: def _diff_text(left: str, right: str, verbose: int = 0) -> list[str]:
"""Return the explanation for the diff between text. """Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing Unless --verbose is used this will skip leading and trailing
@ -282,7 +284,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
""" """
from difflib import ndiff from difflib import ndiff
explanation: List[str] = [] explanation: list[str] = []
if verbose < 1: if verbose < 1:
i = 0 # just in case left or right has zero length i = 0 # just in case left or right has zero length
@ -292,7 +294,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if i > 42: if i > 42:
i -= 10 # Provide some context i -= 10 # Provide some context
explanation = [ explanation = [
"Skipping %s identical leading characters in diff, use -v to show" % i 'Skipping %s identical leading characters in diff, use -v to show' % i,
] ]
left = left[i:] left = left[i:]
right = right[i:] right = right[i:]
@ -303,8 +305,8 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if i > 42: if i > 42:
i -= 10 # Provide some context i -= 10 # Provide some context
explanation += [ explanation += [
f"Skipping {i} identical trailing " f'Skipping {i} identical trailing '
"characters in diff, use -v to show" 'characters in diff, use -v to show',
] ]
left = left[:-i] left = left[:-i]
right = right[:-i] right = right[:-i]
@ -312,11 +314,11 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if left.isspace() or right.isspace(): if left.isspace() or right.isspace():
left = repr(str(left)) left = repr(str(left))
right = repr(str(right)) right = repr(str(right))
explanation += ["Strings contain only whitespace, escaping them using repr()"] explanation += ['Strings contain only whitespace, escaping them using repr()']
# "right" is the expected base against which we compare "left", # "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333 # see https://github.com/pytest-dev/pytest/issues/3333
explanation += [ explanation += [
line.strip("\n") line.strip('\n')
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends)) for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
] ]
return explanation return explanation
@ -327,26 +329,26 @@ def _compare_eq_iterable(
right: Iterable[Any], right: Iterable[Any],
highligher: _HighlightFunc, highligher: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
if verbose <= 0 and not running_on_ci(): if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"] return ['Use -v to get more diff']
# dynamic import to speedup pytest # dynamic import to speedup pytest
import difflib import difflib
left_formatting = PrettyPrinter().pformat(left).splitlines() left_formatting = PrettyPrinter().pformat(left).splitlines()
right_formatting = PrettyPrinter().pformat(right).splitlines() right_formatting = PrettyPrinter().pformat(right).splitlines()
explanation = ["", "Full diff:"] explanation = ['', 'Full diff:']
# "right" is the expected base against which we compare "left", # "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333 # see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend( explanation.extend(
highligher( highligher(
"\n".join( '\n'.join(
line.rstrip() line.rstrip()
for line in difflib.ndiff(right_formatting, left_formatting) for line in difflib.ndiff(right_formatting, left_formatting)
), ),
lexer="diff", lexer='diff',
).splitlines() ).splitlines(),
) )
return explanation return explanation
@ -356,9 +358,9 @@ def _compare_eq_sequence(
right: Sequence[Any], right: Sequence[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: List[str] = [] explanation: list[str] = []
len_left = len(left) len_left = len(left)
len_right = len(right) len_right = len(right)
for i in range(min(len_left, len_right)): for i in range(min(len_left, len_right)):
@ -372,15 +374,15 @@ def _compare_eq_sequence(
# 102 # 102
# >>> s[0:1] # >>> s[0:1]
# b'f' # b'f'
left_value = left[i : i + 1] left_value = left[i: i + 1]
right_value = right[i : i + 1] right_value = right[i: i + 1]
else: else:
left_value = left[i] left_value = left[i]
right_value = right[i] right_value = right[i]
explanation.append( explanation.append(
f"At index {i} diff:" f'At index {i} diff:'
f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}" f' {highlighter(repr(left_value))} != {highlighter(repr(right_value))}',
) )
break break
@ -393,21 +395,21 @@ def _compare_eq_sequence(
len_diff = len_left - len_right len_diff = len_left - len_right
if len_diff: if len_diff:
if len_diff > 0: if len_diff > 0:
dir_with_more = "Left" dir_with_more = 'Left'
extra = saferepr(left[len_right]) extra = saferepr(left[len_right])
else: else:
len_diff = 0 - len_diff len_diff = 0 - len_diff
dir_with_more = "Right" dir_with_more = 'Right'
extra = saferepr(right[len_left]) extra = saferepr(right[len_left])
if len_diff == 1: if len_diff == 1:
explanation += [ explanation += [
f"{dir_with_more} contains one more item: {highlighter(extra)}" f'{dir_with_more} contains one more item: {highlighter(extra)}',
] ]
else: else:
explanation += [ explanation += [
"%s contains %d more items, first extra item: %s" '%s contains %d more items, first extra item: %s'
% (dir_with_more, len_diff, highlighter(extra)) % (dir_with_more, len_diff, highlighter(extra)),
] ]
return explanation return explanation
@ -417,10 +419,10 @@ def _compare_eq_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation = [] explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) explanation.extend(_set_one_sided_diff('left', left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) explanation.extend(_set_one_sided_diff('right', right, left, highlighter))
return explanation return explanation
@ -429,10 +431,10 @@ def _compare_gt_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter) explanation = _compare_gte_set(left, right, highlighter)
if not explanation: if not explanation:
return ["Both sets are equal"] return ['Both sets are equal']
return explanation return explanation
@ -441,10 +443,10 @@ def _compare_lt_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter) explanation = _compare_lte_set(left, right, highlighter)
if not explanation: if not explanation:
return ["Both sets are equal"] return ['Both sets are equal']
return explanation return explanation
@ -453,8 +455,8 @@ def _compare_gte_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter) return _set_one_sided_diff('right', right, left, highlighter)
def _compare_lte_set( def _compare_lte_set(
@ -462,8 +464,8 @@ def _compare_lte_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
return _set_one_sided_diff("left", left, right, highlighter) return _set_one_sided_diff('left', left, right, highlighter)
def _set_one_sided_diff( def _set_one_sided_diff(
@ -471,11 +473,11 @@ def _set_one_sided_diff(
set1: AbstractSet[Any], set1: AbstractSet[Any],
set2: AbstractSet[Any], set2: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
) -> List[str]: ) -> list[str]:
explanation = [] explanation = []
diff = set1 - set2 diff = set1 - set2
if diff: if diff:
explanation.append(f"Extra items in the {posn} set:") explanation.append(f'Extra items in the {posn} set:')
for item in diff: for item in diff:
explanation.append(highlighter(saferepr(item))) explanation.append(highlighter(saferepr(item)))
return explanation return explanation
@ -486,52 +488,52 @@ def _compare_eq_dict(
right: Mapping[Any, Any], right: Mapping[Any, Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation: List[str] = [] explanation: list[str] = []
set_left = set(left) set_left = set(left)
set_right = set(right) set_right = set(right)
common = set_left.intersection(set_right) common = set_left.intersection(set_right)
same = {k: left[k] for k in common if left[k] == right[k]} same = {k: left[k] for k in common if left[k] == right[k]}
if same and verbose < 2: if same and verbose < 2:
explanation += ["Omitting %s identical items, use -vv to show" % len(same)] explanation += ['Omitting %s identical items, use -vv to show' % len(same)]
elif same: elif same:
explanation += ["Common items:"] explanation += ['Common items:']
explanation += highlighter(pprint.pformat(same)).splitlines() explanation += highlighter(pprint.pformat(same)).splitlines()
diff = {k for k in common if left[k] != right[k]} diff = {k for k in common if left[k] != right[k]}
if diff: if diff:
explanation += ["Differing items:"] explanation += ['Differing items:']
for k in diff: for k in diff:
explanation += [ explanation += [
highlighter(saferepr({k: left[k]})) highlighter(saferepr({k: left[k]})) +
+ " != " ' != ' +
+ highlighter(saferepr({k: right[k]})) highlighter(saferepr({k: right[k]})),
] ]
extra_left = set_left - set_right extra_left = set_left - set_right
len_extra_left = len(extra_left) len_extra_left = len(extra_left)
if len_extra_left: if len_extra_left:
explanation.append( explanation.append(
"Left contains %d more item%s:" 'Left contains %d more item%s:'
% (len_extra_left, "" if len_extra_left == 1 else "s") % (len_extra_left, '' if len_extra_left == 1 else 's'),
) )
explanation.extend( explanation.extend(
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines() highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines(),
) )
extra_right = set_right - set_left extra_right = set_right - set_left
len_extra_right = len(extra_right) len_extra_right = len(extra_right)
if len_extra_right: if len_extra_right:
explanation.append( explanation.append(
"Right contains %d more item%s:" 'Right contains %d more item%s:'
% (len_extra_right, "" if len_extra_right == 1 else "s") % (len_extra_right, '' if len_extra_right == 1 else 's'),
) )
explanation.extend( explanation.extend(
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines() highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines(),
) )
return explanation return explanation
def _compare_eq_cls( def _compare_eq_cls(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int left: Any, right: Any, highlighter: _HighlightFunc, verbose: int,
) -> List[str]: ) -> list[str]:
if not has_default_eq(left): if not has_default_eq(left):
return [] return []
if isdatacls(left): if isdatacls(left):
@ -541,13 +543,13 @@ def _compare_eq_cls(
fields_to_check = [info.name for info in all_fields if info.compare] fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left): elif isattrs(left):
all_fields = left.__attrs_attrs__ all_fields = left.__attrs_attrs__
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")] fields_to_check = [field.name for field in all_fields if getattr(field, 'eq')]
elif isnamedtuple(left): elif isnamedtuple(left):
fields_to_check = left._fields fields_to_check = left._fields
else: else:
assert False assert False
indent = " " indent = ' '
same = [] same = []
diff = [] diff = []
for field in fields_to_check: for field in fields_to_check:
@ -558,46 +560,46 @@ def _compare_eq_cls(
explanation = [] explanation = []
if same or diff: if same or diff:
explanation += [""] explanation += ['']
if same and verbose < 2: if same and verbose < 2:
explanation.append("Omitting %s identical items, use -vv to show" % len(same)) explanation.append('Omitting %s identical items, use -vv to show' % len(same))
elif same: elif same:
explanation += ["Matching attributes:"] explanation += ['Matching attributes:']
explanation += highlighter(pprint.pformat(same)).splitlines() explanation += highlighter(pprint.pformat(same)).splitlines()
if diff: if diff:
explanation += ["Differing attributes:"] explanation += ['Differing attributes:']
explanation += highlighter(pprint.pformat(diff)).splitlines() explanation += highlighter(pprint.pformat(diff)).splitlines()
for field in diff: for field in diff:
field_left = getattr(left, field) field_left = getattr(left, field)
field_right = getattr(right, field) field_right = getattr(right, field)
explanation += [ explanation += [
"", '',
f"Drill down into differing attribute {field}:", f'Drill down into differing attribute {field}:',
f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}", f'{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}',
] ]
explanation += [ explanation += [
indent + line indent + line
for line in _compare_eq_any( for line in _compare_eq_any(
field_left, field_right, highlighter, verbose field_left, field_right, highlighter, verbose,
) )
] ]
return explanation return explanation
def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]: def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
index = text.find(term) index = text.find(term)
head = text[:index] head = text[:index]
tail = text[index + len(term) :] tail = text[index + len(term):]
correct_text = head + tail correct_text = head + tail
diff = _diff_text(text, correct_text, verbose) diff = _diff_text(text, correct_text, verbose)
newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)] newdiff = ['%s is contained here:' % saferepr(term, maxsize=42)]
for line in diff: for line in diff:
if line.startswith("Skipping"): if line.startswith('Skipping'):
continue continue
if line.startswith("- "): if line.startswith('- '):
continue continue
if line.startswith("+ "): if line.startswith('+ '):
newdiff.append(" " + line[2:]) newdiff.append(' ' + line[2:])
else: else:
newdiff.append(line) newdiff.append(line)
return newdiff return newdiff
@ -605,5 +607,5 @@ def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
def running_on_ci() -> bool: def running_on_ci() -> bool:
"""Check if we're currently running on a CI system.""" """Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"] env_vars = ['CI', 'BUILD_NUMBER']
return any(var in os.environ for var in env_vars) return any(var in os.environ for var in env_vars)

View file

@ -2,6 +2,8 @@
"""Implementation of the cache provider.""" """Implementation of the cache provider."""
# This plugin was not named "cache" to avoid conflicts with the external # This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version. # pytest-cache version.
from __future__ import annotations
import dataclasses import dataclasses
import json import json
import os import os
@ -15,9 +17,6 @@ from typing import Optional
from typing import Set from typing import Set
from typing import Union from typing import Union
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes from _pytest import nodes
from _pytest._io import TerminalWriter from _pytest._io import TerminalWriter
from _pytest.config import Config from _pytest.config import Config
@ -32,6 +31,10 @@ from _pytest.nodes import Directory
from _pytest.nodes import File from _pytest.nodes import File
from _pytest.reports import TestReport from _pytest.reports import TestReport
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
README_CONTENT = """\ README_CONTENT = """\
# pytest cache directory # # pytest cache directory #
@ -61,27 +64,27 @@ class Cache:
_config: Config = dataclasses.field(repr=False) _config: Config = dataclasses.field(repr=False)
# Sub-directory under cache-dir for directories created by `mkdir()`. # Sub-directory under cache-dir for directories created by `mkdir()`.
_CACHE_PREFIX_DIRS = "d" _CACHE_PREFIX_DIRS = 'd'
# Sub-directory under cache-dir for values created by `set()`. # Sub-directory under cache-dir for values created by `set()`.
_CACHE_PREFIX_VALUES = "v" _CACHE_PREFIX_VALUES = 'v'
def __init__( def __init__(
self, cachedir: Path, config: Config, *, _ispytest: bool = False self, cachedir: Path, config: Config, *, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._cachedir = cachedir self._cachedir = cachedir
self._config = config self._config = config
@classmethod @classmethod
def for_config(cls, config: Config, *, _ispytest: bool = False) -> "Cache": def for_config(cls, config: Config, *, _ispytest: bool = False) -> Cache:
"""Create the Cache instance for a Config. """Create the Cache instance for a Config.
:meta private: :meta private:
""" """
check_ispytest(_ispytest) check_ispytest(_ispytest)
cachedir = cls.cache_dir_from_config(config, _ispytest=True) cachedir = cls.cache_dir_from_config(config, _ispytest=True)
if config.getoption("cacheclear") and cachedir.is_dir(): if config.getoption('cacheclear') and cachedir.is_dir():
cls.clear_cache(cachedir, _ispytest=True) cls.clear_cache(cachedir, _ispytest=True)
return cls(cachedir, config, _ispytest=True) return cls(cachedir, config, _ispytest=True)
@ -104,7 +107,7 @@ class Cache:
:meta private: :meta private:
""" """
check_ispytest(_ispytest) check_ispytest(_ispytest)
return resolve_from_str(config.getini("cache_dir"), config.rootpath) return resolve_from_str(config.getini('cache_dir'), config.rootpath)
def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None: def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
"""Issue a cache warning. """Issue a cache warning.
@ -138,7 +141,7 @@ class Cache:
""" """
path = Path(name) path = Path(name)
if len(path.parts) > 1: if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators") raise ValueError('name is not allowed to contain path separators')
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path) res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
res.mkdir(exist_ok=True, parents=True) res.mkdir(exist_ok=True, parents=True)
return res return res
@ -160,7 +163,7 @@ class Cache:
""" """
path = self._getvaluepath(key) path = self._getvaluepath(key)
try: try:
with path.open("r", encoding="UTF-8") as f: with path.open('r', encoding='UTF-8') as f:
return json.load(f) return json.load(f)
except (ValueError, OSError): except (ValueError, OSError):
return default return default
@ -184,7 +187,7 @@ class Cache:
path.parent.mkdir(exist_ok=True, parents=True) path.parent.mkdir(exist_ok=True, parents=True)
except OSError as exc: except OSError as exc:
self.warn( self.warn(
f"could not create cache path {path}: {exc}", f'could not create cache path {path}: {exc}',
_ispytest=True, _ispytest=True,
) )
return return
@ -192,10 +195,10 @@ class Cache:
self._ensure_supporting_files() self._ensure_supporting_files()
data = json.dumps(value, ensure_ascii=False, indent=2) data = json.dumps(value, ensure_ascii=False, indent=2)
try: try:
f = path.open("w", encoding="UTF-8") f = path.open('w', encoding='UTF-8')
except OSError as exc: except OSError as exc:
self.warn( self.warn(
f"cache could not write path {path}: {exc}", f'cache could not write path {path}: {exc}',
_ispytest=True, _ispytest=True,
) )
else: else:
@ -204,25 +207,25 @@ class Cache:
def _ensure_supporting_files(self) -> None: def _ensure_supporting_files(self) -> None:
"""Create supporting files in the cache dir that are not really part of the cache.""" """Create supporting files in the cache dir that are not really part of the cache."""
readme_path = self._cachedir / "README.md" readme_path = self._cachedir / 'README.md'
readme_path.write_text(README_CONTENT, encoding="UTF-8") readme_path.write_text(README_CONTENT, encoding='UTF-8')
gitignore_path = self._cachedir.joinpath(".gitignore") gitignore_path = self._cachedir.joinpath('.gitignore')
msg = "# Created by pytest automatically.\n*\n" msg = '# Created by pytest automatically.\n*\n'
gitignore_path.write_text(msg, encoding="UTF-8") gitignore_path.write_text(msg, encoding='UTF-8')
cachedir_tag_path = self._cachedir.joinpath("CACHEDIR.TAG") cachedir_tag_path = self._cachedir.joinpath('CACHEDIR.TAG')
cachedir_tag_path.write_bytes(CACHEDIR_TAG_CONTENT) cachedir_tag_path.write_bytes(CACHEDIR_TAG_CONTENT)
class LFPluginCollWrapper: class LFPluginCollWrapper:
def __init__(self, lfplugin: "LFPlugin") -> None: def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin self.lfplugin = lfplugin
self._collected_at_least_one_failure = False self._collected_at_least_one_failure = False
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_make_collect_report( def pytest_make_collect_report(
self, collector: nodes.Collector self, collector: nodes.Collector,
) -> Generator[None, CollectReport, CollectReport]: ) -> Generator[None, CollectReport, CollectReport]:
res = yield res = yield
if isinstance(collector, (Session, Directory)): if isinstance(collector, (Session, Directory)):
@ -230,7 +233,7 @@ class LFPluginCollWrapper:
lf_paths = self.lfplugin._last_failed_paths lf_paths = self.lfplugin._last_failed_paths
# Use stable sort to priorize last failed. # Use stable sort to priorize last failed.
def sort_key(node: Union[nodes.Item, nodes.Collector]) -> bool: def sort_key(node: nodes.Item | nodes.Collector) -> bool:
return node.path in lf_paths return node.path in lf_paths
res.result = sorted( res.result = sorted(
@ -249,7 +252,7 @@ class LFPluginCollWrapper:
if not any(x.nodeid in lastfailed for x in result): if not any(x.nodeid in lastfailed for x in result):
return res return res
self.lfplugin.config.pluginmanager.register( self.lfplugin.config.pluginmanager.register(
LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip" LFPluginCollSkipfiles(self.lfplugin), 'lfplugin-collskip',
) )
self._collected_at_least_one_failure = True self._collected_at_least_one_failure = True
@ -257,30 +260,30 @@ class LFPluginCollWrapper:
result[:] = [ result[:] = [
x x
for x in result for x in result
if x.nodeid in lastfailed if x.nodeid in lastfailed or
# Include any passed arguments (not trivial to filter). # Include any passed arguments (not trivial to filter).
or session.isinitpath(x.path) session.isinitpath(x.path) or
# Keep all sub-collectors. # Keep all sub-collectors.
or isinstance(x, nodes.Collector) isinstance(x, nodes.Collector)
] ]
return res return res
class LFPluginCollSkipfiles: class LFPluginCollSkipfiles:
def __init__(self, lfplugin: "LFPlugin") -> None: def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin self.lfplugin = lfplugin
@hookimpl @hookimpl
def pytest_make_collect_report( def pytest_make_collect_report(
self, collector: nodes.Collector self, collector: nodes.Collector,
) -> Optional[CollectReport]: ) -> CollectReport | None:
if isinstance(collector, File): if isinstance(collector, File):
if collector.path not in self.lfplugin._last_failed_paths: if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1 self.lfplugin._skipped_files += 1
return CollectReport( return CollectReport(
collector.nodeid, "passed", longrepr=None, result=[] collector.nodeid, 'passed', longrepr=None, result=[],
) )
return None return None
@ -290,44 +293,44 @@ class LFPlugin:
def __init__(self, config: Config) -> None: def __init__(self, config: Config) -> None:
self.config = config self.config = config
active_keys = "lf", "failedfirst" active_keys = 'lf', 'failedfirst'
self.active = any(config.getoption(key) for key in active_keys) self.active = any(config.getoption(key) for key in active_keys)
assert config.cache assert config.cache
self.lastfailed: Dict[str, bool] = config.cache.get("cache/lastfailed", {}) self.lastfailed: dict[str, bool] = config.cache.get('cache/lastfailed', {})
self._previously_failed_count: Optional[int] = None self._previously_failed_count: int | None = None
self._report_status: Optional[str] = None self._report_status: str | None = None
self._skipped_files = 0 # count skipped files during collection due to --lf self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"): if config.getoption('lf'):
self._last_failed_paths = self.get_last_failed_paths() self._last_failed_paths = self.get_last_failed_paths()
config.pluginmanager.register( config.pluginmanager.register(
LFPluginCollWrapper(self), "lfplugin-collwrapper" LFPluginCollWrapper(self), 'lfplugin-collwrapper',
) )
def get_last_failed_paths(self) -> Set[Path]: def get_last_failed_paths(self) -> set[Path]:
"""Return a set with all Paths of the previously failed nodeids and """Return a set with all Paths of the previously failed nodeids and
their parents.""" their parents."""
rootpath = self.config.rootpath rootpath = self.config.rootpath
result = set() result = set()
for nodeid in self.lastfailed: for nodeid in self.lastfailed:
path = rootpath / nodeid.split("::")[0] path = rootpath / nodeid.split('::')[0]
result.add(path) result.add(path)
result.update(path.parents) result.update(path.parents)
return {x for x in result if x.exists()} return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> Optional[str]: def pytest_report_collectionfinish(self) -> str | None:
if self.active and self.config.getoption("verbose") >= 0: if self.active and self.config.getoption('verbose') >= 0:
return "run-last-failure: %s" % self._report_status return 'run-last-failure: %s' % self._report_status
return None return None
def pytest_runtest_logreport(self, report: TestReport) -> None: def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped: if (report.when == 'call' and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None) self.lastfailed.pop(report.nodeid, None)
elif report.failed: elif report.failed:
self.lastfailed[report.nodeid] = True self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report: CollectReport) -> None: def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped") passed = report.outcome in ('passed', 'skipped')
if passed: if passed:
if report.nodeid in self.lastfailed: if report.nodeid in self.lastfailed:
self.lastfailed.pop(report.nodeid) self.lastfailed.pop(report.nodeid)
@ -337,7 +340,7 @@ class LFPlugin:
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item] self, config: Config, items: list[nodes.Item],
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
res = yield res = yield
@ -357,45 +360,45 @@ class LFPlugin:
if not previously_failed: if not previously_failed:
# Running a subset of all tests with recorded failures # Running a subset of all tests with recorded failures
# only outside of it. # only outside of it.
self._report_status = "%d known failures not in selected tests" % ( self._report_status = '%d known failures not in selected tests' % (
len(self.lastfailed), len(self.lastfailed),
) )
else: else:
if self.config.getoption("lf"): if self.config.getoption('lf'):
items[:] = previously_failed items[:] = previously_failed
config.hook.pytest_deselected(items=previously_passed) config.hook.pytest_deselected(items=previously_passed)
else: # --failedfirst else: # --failedfirst
items[:] = previously_failed + previously_passed items[:] = previously_failed + previously_passed
noun = "failure" if self._previously_failed_count == 1 else "failures" noun = 'failure' if self._previously_failed_count == 1 else 'failures'
suffix = " first" if self.config.getoption("failedfirst") else "" suffix = ' first' if self.config.getoption('failedfirst') else ''
self._report_status = ( self._report_status = (
f"rerun previous {self._previously_failed_count} {noun}{suffix}" f'rerun previous {self._previously_failed_count} {noun}{suffix}'
) )
if self._skipped_files > 0: if self._skipped_files > 0:
files_noun = "file" if self._skipped_files == 1 else "files" files_noun = 'file' if self._skipped_files == 1 else 'files'
self._report_status += f" (skipped {self._skipped_files} {files_noun})" self._report_status += f' (skipped {self._skipped_files} {files_noun})'
else: else:
self._report_status = "no previously failed tests, " self._report_status = 'no previously failed tests, '
if self.config.getoption("last_failed_no_failures") == "none": if self.config.getoption('last_failed_no_failures') == 'none':
self._report_status += "deselecting all items." self._report_status += 'deselecting all items.'
config.hook.pytest_deselected(items=items[:]) config.hook.pytest_deselected(items=items[:])
items[:] = [] items[:] = []
else: else:
self._report_status += "not deselecting items." self._report_status += 'not deselecting items.'
return res return res
def pytest_sessionfinish(self, session: Session) -> None: def pytest_sessionfinish(self, session: Session) -> None:
config = self.config config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"): if config.getoption('cacheshow') or hasattr(config, 'workerinput'):
return return
assert config.cache is not None assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {}) saved_lastfailed = config.cache.get('cache/lastfailed', {})
if saved_lastfailed != self.lastfailed: if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed) config.cache.set('cache/lastfailed', self.lastfailed)
class NFPlugin: class NFPlugin:
@ -405,17 +408,17 @@ class NFPlugin:
self.config = config self.config = config
self.active = config.option.newfirst self.active = config.option.newfirst
assert config.cache is not None assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", [])) self.cached_nodeids = set(config.cache.get('cache/nodeids', []))
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
self, items: List[nodes.Item] self, items: list[nodes.Item],
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
res = yield res = yield
if self.active: if self.active:
new_items: Dict[str, nodes.Item] = {} new_items: dict[str, nodes.Item] = {}
other_items: Dict[str, nodes.Item] = {} other_items: dict[str, nodes.Item] = {}
for item in items: for item in items:
if item.nodeid not in self.cached_nodeids: if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item new_items[item.nodeid] = item
@ -423,7 +426,7 @@ class NFPlugin:
other_items[item.nodeid] = item other_items[item.nodeid] = item
items[:] = self._get_increasing_order( items[:] = self._get_increasing_order(
new_items.values() new_items.values(),
) + self._get_increasing_order(other_items.values()) ) + self._get_increasing_order(other_items.values())
self.cached_nodeids.update(new_items) self.cached_nodeids.update(new_items)
else: else:
@ -431,84 +434,84 @@ class NFPlugin:
return res return res
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> List[nodes.Item]: def _get_increasing_order(self, items: Iterable[nodes.Item]) -> list[nodes.Item]:
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True) # type: ignore[no-any-return] return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True) # type: ignore[no-any-return]
def pytest_sessionfinish(self) -> None: def pytest_sessionfinish(self) -> None:
config = self.config config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"): if config.getoption('cacheshow') or hasattr(config, 'workerinput'):
return return
if config.getoption("collectonly"): if config.getoption('collectonly'):
return return
assert config.cache is not None assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids)) config.cache.set('cache/nodeids', sorted(self.cached_nodeids))
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general") group = parser.getgroup('general')
group.addoption( group.addoption(
"--lf", '--lf',
"--last-failed", '--last-failed',
action="store_true", action='store_true',
dest="lf", dest='lf',
help="Rerun only the tests that failed " help='Rerun only the tests that failed '
"at the last run (or all if none failed)", 'at the last run (or all if none failed)',
) )
group.addoption( group.addoption(
"--ff", '--ff',
"--failed-first", '--failed-first',
action="store_true", action='store_true',
dest="failedfirst", dest='failedfirst',
help="Run all tests, but run the last failures first. " help='Run all tests, but run the last failures first. '
"This may re-order tests and thus lead to " 'This may re-order tests and thus lead to '
"repeated fixture setup/teardown.", 'repeated fixture setup/teardown.',
) )
group.addoption( group.addoption(
"--nf", '--nf',
"--new-first", '--new-first',
action="store_true", action='store_true',
dest="newfirst", dest='newfirst',
help="Run tests from new files first, then the rest of the tests " help='Run tests from new files first, then the rest of the tests '
"sorted by file mtime", 'sorted by file mtime',
) )
group.addoption( group.addoption(
"--cache-show", '--cache-show',
action="append", action='append',
nargs="?", nargs='?',
dest="cacheshow", dest='cacheshow',
help=( help=(
"Show cache contents, don't perform collection or tests. " "Show cache contents, don't perform collection or tests. "
"Optional argument: glob (default: '*')." "Optional argument: glob (default: '*')."
), ),
) )
group.addoption( group.addoption(
"--cache-clear", '--cache-clear',
action="store_true", action='store_true',
dest="cacheclear", dest='cacheclear',
help="Remove all cache contents at start of test run", help='Remove all cache contents at start of test run',
) )
cache_dir_default = ".pytest_cache" cache_dir_default = '.pytest_cache'
if "TOX_ENV_DIR" in os.environ: if 'TOX_ENV_DIR' in os.environ:
cache_dir_default = os.path.join(os.environ["TOX_ENV_DIR"], cache_dir_default) cache_dir_default = os.path.join(os.environ['TOX_ENV_DIR'], cache_dir_default)
parser.addini("cache_dir", default=cache_dir_default, help="Cache directory path") parser.addini('cache_dir', default=cache_dir_default, help='Cache directory path')
group.addoption( group.addoption(
"--lfnf", '--lfnf',
"--last-failed-no-failures", '--last-failed-no-failures',
action="store", action='store',
dest="last_failed_no_failures", dest='last_failed_no_failures',
choices=("all", "none"), choices=('all', 'none'),
default="all", default='all',
help="With ``--lf``, determines whether to execute tests when there " help='With ``--lf``, determines whether to execute tests when there '
"are no previously (known) failures or when no " 'are no previously (known) failures or when no '
"cached ``lastfailed`` data was found. " 'cached ``lastfailed`` data was found. '
"``all`` (the default) runs the full test suite again. " '``all`` (the default) runs the full test suite again. '
"``none`` just emits a message about no known failures and exits successfully.", '``none`` just emits a message about no known failures and exits successfully.',
) )
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.cacheshow and not config.option.help: if config.option.cacheshow and not config.option.help:
from _pytest.main import wrap_session from _pytest.main import wrap_session
@ -519,8 +522,8 @@ def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
config.cache = Cache.for_config(config, _ispytest=True) config.cache = Cache.for_config(config, _ispytest=True)
config.pluginmanager.register(LFPlugin(config), "lfplugin") config.pluginmanager.register(LFPlugin(config), 'lfplugin')
config.pluginmanager.register(NFPlugin(config), "nfplugin") config.pluginmanager.register(NFPlugin(config), 'nfplugin')
@fixture @fixture
@ -539,9 +542,9 @@ def cache(request: FixtureRequest) -> Cache:
return request.config.cache return request.config.cache
def pytest_report_header(config: Config) -> Optional[str]: def pytest_report_header(config: Config) -> str | None:
"""Display cachedir with --cache-show and if non-default.""" """Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache": if config.option.verbose > 0 or config.getini('cache_dir') != '.pytest_cache':
assert config.cache is not None assert config.cache is not None
cachedir = config.cache._cachedir cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths # TODO: evaluate generating upward relative paths
@ -551,7 +554,7 @@ def pytest_report_header(config: Config) -> Optional[str]:
displaypath = cachedir.relative_to(config.rootpath) displaypath = cachedir.relative_to(config.rootpath)
except ValueError: except ValueError:
displaypath = cachedir displaypath = cachedir
return f"cachedir: {displaypath}" return f'cachedir: {displaypath}'
return None return None
@ -561,37 +564,37 @@ def cacheshow(config: Config, session: Session) -> int:
assert config.cache is not None assert config.cache is not None
tw = TerminalWriter() tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir)) tw.line('cachedir: ' + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir(): if not config.cache._cachedir.is_dir():
tw.line("cache is empty") tw.line('cache is empty')
return 0 return 0
glob = config.option.cacheshow[0] glob = config.option.cacheshow[0]
if glob is None: if glob is None:
glob = "*" glob = '*'
dummy = object() dummy = object()
basedir = config.cache._cachedir basedir = config.cache._cachedir
vdir = basedir / Cache._CACHE_PREFIX_VALUES vdir = basedir / Cache._CACHE_PREFIX_VALUES
tw.sep("-", "cache values for %r" % glob) tw.sep('-', 'cache values for %r' % glob)
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()): for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
key = str(valpath.relative_to(vdir)) key = str(valpath.relative_to(vdir))
val = config.cache.get(key, dummy) val = config.cache.get(key, dummy)
if val is dummy: if val is dummy:
tw.line("%s contains unreadable content, will be ignored" % key) tw.line('%s contains unreadable content, will be ignored' % key)
else: else:
tw.line("%s contains:" % key) tw.line('%s contains:' % key)
for line in pformat(val).splitlines(): for line in pformat(val).splitlines():
tw.line(" " + line) tw.line(' ' + line)
ddir = basedir / Cache._CACHE_PREFIX_DIRS ddir = basedir / Cache._CACHE_PREFIX_DIRS
if ddir.is_dir(): if ddir.is_dir():
contents = sorted(ddir.rglob(glob)) contents = sorted(ddir.rglob(glob))
tw.sep("-", "cache directories for %r" % glob) tw.sep('-', 'cache directories for %r' % glob)
for p in contents: for p in contents:
# if p.is_dir(): # if p.is_dir():
# print("%s/" % p.relative_to(basedir)) # print("%s/" % p.relative_to(basedir))
if p.is_file(): if p.is_file():
key = str(p.relative_to(basedir)) key = str(p.relative_to(basedir))
tw.line(f"{key} is a file of length {p.stat().st_size:d}") tw.line(f'{key} is a file of length {p.stat().st_size:d}')
return 0 return 0

View file

@ -1,12 +1,14 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Per-test stdout/stderr capturing mechanism.""" """Per-test stdout/stderr capturing mechanism."""
from __future__ import annotations
import abc import abc
import collections import collections
import contextlib import contextlib
import io import io
from io import UnsupportedOperation
import os import os
import sys import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile from tempfile import TemporaryFile
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
@ -40,25 +42,25 @@ from _pytest.nodes import Item
from _pytest.reports import CollectReport from _pytest.reports import CollectReport
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"] _CaptureMethod = Literal['fd', 'sys', 'no', 'tee-sys']
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general") group = parser.getgroup('general')
group._addoption( group._addoption(
"--capture", '--capture',
action="store", action='store',
default="fd", default='fd',
metavar="method", metavar='method',
choices=["fd", "sys", "no", "tee-sys"], choices=['fd', 'sys', 'no', 'tee-sys'],
help="Per-test capturing method: one of fd|sys|no|tee-sys", help='Per-test capturing method: one of fd|sys|no|tee-sys',
) )
group._addoption( group._addoption(
"-s", '-s',
action="store_const", action='store_const',
const="no", const='no',
dest="capture", dest='capture',
help="Shortcut for --capture=no", help='Shortcut for --capture=no',
) )
@ -70,7 +72,7 @@ def _colorama_workaround() -> None:
first import of colorama while I/O capture is active, colorama will first import of colorama while I/O capture is active, colorama will
fail in various ways. fail in various ways.
""" """
if sys.platform.startswith("win32"): if sys.platform.startswith('win32'):
try: try:
import colorama # noqa: F401 import colorama # noqa: F401
except ImportError: except ImportError:
@ -101,21 +103,21 @@ def _windowsconsoleio_workaround(stream: TextIO) -> None:
See https://github.com/pytest-dev/py/issues/103. See https://github.com/pytest-dev/py/issues/103.
""" """
if not sys.platform.startswith("win32") or hasattr(sys, "pypy_version_info"): if not sys.platform.startswith('win32') or hasattr(sys, 'pypy_version_info'):
return return
# Bail out if ``stream`` doesn't seem like a proper ``io`` stream (#2666). # Bail out if ``stream`` doesn't seem like a proper ``io`` stream (#2666).
if not hasattr(stream, "buffer"): # type: ignore[unreachable] if not hasattr(stream, 'buffer'): # type: ignore[unreachable]
return return
buffered = hasattr(stream.buffer, "raw") buffered = hasattr(stream.buffer, 'raw')
raw_stdout = stream.buffer.raw if buffered else stream.buffer # type: ignore[attr-defined] raw_stdout = stream.buffer.raw if buffered else stream.buffer # type: ignore[attr-defined]
if not isinstance(raw_stdout, io._WindowsConsoleIO): # type: ignore[attr-defined] if not isinstance(raw_stdout, io._WindowsConsoleIO): # type: ignore[attr-defined]
return return
def _reopen_stdio(f, mode): def _reopen_stdio(f, mode):
if not buffered and mode[0] == "w": if not buffered and mode[0] == 'w':
buffering = 0 buffering = 0
else: else:
buffering = -1 buffering = -1
@ -128,20 +130,20 @@ def _windowsconsoleio_workaround(stream: TextIO) -> None:
f.line_buffering, f.line_buffering,
) )
sys.stdin = _reopen_stdio(sys.stdin, "rb") sys.stdin = _reopen_stdio(sys.stdin, 'rb')
sys.stdout = _reopen_stdio(sys.stdout, "wb") sys.stdout = _reopen_stdio(sys.stdout, 'wb')
sys.stderr = _reopen_stdio(sys.stderr, "wb") sys.stderr = _reopen_stdio(sys.stderr, 'wb')
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_load_initial_conftests(early_config: Config) -> Generator[None, None, None]: def pytest_load_initial_conftests(early_config: Config) -> Generator[None, None, None]:
ns = early_config.known_args_namespace ns = early_config.known_args_namespace
if ns.capture == "fd": if ns.capture == 'fd':
_windowsconsoleio_workaround(sys.stdout) _windowsconsoleio_workaround(sys.stdout)
_colorama_workaround() _colorama_workaround()
pluginmanager = early_config.pluginmanager pluginmanager = early_config.pluginmanager
capman = CaptureManager(ns.capture) capman = CaptureManager(ns.capture)
pluginmanager.register(capman, "capturemanager") pluginmanager.register(capman, 'capturemanager')
# Make sure that capturemanager is properly reset at final shutdown. # Make sure that capturemanager is properly reset at final shutdown.
early_config.add_cleanup(capman.stop_global_capturing) early_config.add_cleanup(capman.stop_global_capturing)
@ -176,16 +178,16 @@ class EncodedFile(io.TextIOWrapper):
def mode(self) -> str: def mode(self) -> str:
# TextIOWrapper doesn't expose a mode, but at least some of our # TextIOWrapper doesn't expose a mode, but at least some of our
# tests check it. # tests check it.
return self.buffer.mode.replace("b", "") return self.buffer.mode.replace('b', '')
class CaptureIO(io.TextIOWrapper): class CaptureIO(io.TextIOWrapper):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True) super().__init__(io.BytesIO(), encoding='UTF-8', newline='', write_through=True)
def getvalue(self) -> str: def getvalue(self) -> str:
assert isinstance(self.buffer, io.BytesIO) assert isinstance(self.buffer, io.BytesIO)
return self.buffer.getvalue().decode("UTF-8") return self.buffer.getvalue().decode('UTF-8')
class TeeCaptureIO(CaptureIO): class TeeCaptureIO(CaptureIO):
@ -205,7 +207,7 @@ class DontReadFromInput(TextIO):
def read(self, size: int = -1) -> str: def read(self, size: int = -1) -> str:
raise OSError( raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`." 'pytest: reading from stdin while output is captured! Consider using `-s`.',
) )
readline = read readline = read
@ -213,19 +215,19 @@ class DontReadFromInput(TextIO):
def __next__(self) -> str: def __next__(self) -> str:
return self.readline() return self.readline()
def readlines(self, hint: Optional[int] = -1) -> List[str]: def readlines(self, hint: int | None = -1) -> list[str]:
raise OSError( raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`." 'pytest: reading from stdin while output is captured! Consider using `-s`.',
) )
def __iter__(self) -> Iterator[str]: def __iter__(self) -> Iterator[str]:
return self return self
def fileno(self) -> int: def fileno(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no fileno()") raise UnsupportedOperation('redirected stdin is pseudofile, has no fileno()')
def flush(self) -> None: def flush(self) -> None:
raise UnsupportedOperation("redirected stdin is pseudofile, has no flush()") raise UnsupportedOperation('redirected stdin is pseudofile, has no flush()')
def isatty(self) -> bool: def isatty(self) -> bool:
return False return False
@ -237,34 +239,34 @@ class DontReadFromInput(TextIO):
return False return False
def seek(self, offset: int, whence: int = 0) -> int: def seek(self, offset: int, whence: int = 0) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)") raise UnsupportedOperation('redirected stdin is pseudofile, has no seek(int)')
def seekable(self) -> bool: def seekable(self) -> bool:
return False return False
def tell(self) -> int: def tell(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()") raise UnsupportedOperation('redirected stdin is pseudofile, has no tell()')
def truncate(self, size: Optional[int] = None) -> int: def truncate(self, size: int | None = None) -> int:
raise UnsupportedOperation("cannot truncate stdin") raise UnsupportedOperation('cannot truncate stdin')
def write(self, data: str) -> int: def write(self, data: str) -> int:
raise UnsupportedOperation("cannot write to stdin") raise UnsupportedOperation('cannot write to stdin')
def writelines(self, lines: Iterable[str]) -> None: def writelines(self, lines: Iterable[str]) -> None:
raise UnsupportedOperation("Cannot write to stdin") raise UnsupportedOperation('Cannot write to stdin')
def writable(self) -> bool: def writable(self) -> bool:
return False return False
def __enter__(self) -> "DontReadFromInput": def __enter__(self) -> DontReadFromInput:
return self return self
def __exit__( def __exit__(
self, self,
type: Optional[Type[BaseException]], type: type[BaseException] | None,
value: Optional[BaseException], value: BaseException | None,
traceback: Optional[TracebackType], traceback: TracebackType | None,
) -> None: ) -> None:
pass pass
@ -309,11 +311,11 @@ class CaptureBase(abc.ABC, Generic[AnyStr]):
raise NotImplementedError() raise NotImplementedError()
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"} patchsysdict = {0: 'stdin', 1: 'stdout', 2: 'stderr'}
class NoCapture(CaptureBase[str]): class NoCapture(CaptureBase[str]):
EMPTY_BUFFER = "" EMPTY_BUFFER = ''
def __init__(self, fd: int) -> None: def __init__(self, fd: int) -> None:
pass pass
@ -331,7 +333,7 @@ class NoCapture(CaptureBase[str]):
pass pass
def snap(self) -> str: def snap(self) -> str:
return "" return ''
def writeorg(self, data: str) -> None: def writeorg(self, data: str) -> None:
pass pass
@ -339,76 +341,76 @@ class NoCapture(CaptureBase[str]):
class SysCaptureBase(CaptureBase[AnyStr]): class SysCaptureBase(CaptureBase[AnyStr]):
def __init__( def __init__(
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False self, fd: int, tmpfile: TextIO | None = None, *, tee: bool = False,
) -> None: ) -> None:
name = patchsysdict[fd] name = patchsysdict[fd]
self._old: TextIO = getattr(sys, name) self._old: TextIO = getattr(sys, name)
self.name = name self.name = name
if tmpfile is None: if tmpfile is None:
if name == "stdin": if name == 'stdin':
tmpfile = DontReadFromInput() tmpfile = DontReadFromInput()
else: else:
tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old) tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old)
self.tmpfile = tmpfile self.tmpfile = tmpfile
self._state = "initialized" self._state = 'initialized'
def repr(self, class_name: str) -> str: def repr(self, class_name: str) -> str:
return "<{} {} _old={} _state={!r} tmpfile={!r}>".format( return '<{} {} _old={} _state={!r} tmpfile={!r}>'.format(
class_name, class_name,
self.name, self.name,
hasattr(self, "_old") and repr(self._old) or "<UNSET>", hasattr(self, '_old') and repr(self._old) or '<UNSET>',
self._state, self._state,
self.tmpfile, self.tmpfile,
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return "<{} {} _old={} _state={!r} tmpfile={!r}>".format( return '<{} {} _old={} _state={!r} tmpfile={!r}>'.format(
self.__class__.__name__, self.__class__.__name__,
self.name, self.name,
hasattr(self, "_old") and repr(self._old) or "<UNSET>", hasattr(self, '_old') and repr(self._old) or '<UNSET>',
self._state, self._state,
self.tmpfile, self.tmpfile,
) )
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None: def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert ( assert (
self._state in states self._state in states
), "cannot {} in state {!r}: expected one of {}".format( ), 'cannot {} in state {!r}: expected one of {}'.format(
op, self._state, ", ".join(states) op, self._state, ', '.join(states),
) )
def start(self) -> None: def start(self) -> None:
self._assert_state("start", ("initialized",)) self._assert_state('start', ('initialized',))
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
self._state = "started" self._state = 'started'
def done(self) -> None: def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done")) self._assert_state('done', ('initialized', 'started', 'suspended', 'done'))
if self._state == "done": if self._state == 'done':
return return
setattr(sys, self.name, self._old) setattr(sys, self.name, self._old)
del self._old del self._old
self.tmpfile.close() self.tmpfile.close()
self._state = "done" self._state = 'done'
def suspend(self) -> None: def suspend(self) -> None:
self._assert_state("suspend", ("started", "suspended")) self._assert_state('suspend', ('started', 'suspended'))
setattr(sys, self.name, self._old) setattr(sys, self.name, self._old)
self._state = "suspended" self._state = 'suspended'
def resume(self) -> None: def resume(self) -> None:
self._assert_state("resume", ("started", "suspended")) self._assert_state('resume', ('started', 'suspended'))
if self._state == "started": if self._state == 'started':
return return
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
self._state = "started" self._state = 'started'
class SysCaptureBinary(SysCaptureBase[bytes]): class SysCaptureBinary(SysCaptureBase[bytes]):
EMPTY_BUFFER = b"" EMPTY_BUFFER = b''
def snap(self) -> bytes: def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended")) self._assert_state('snap', ('started', 'suspended'))
self.tmpfile.seek(0) self.tmpfile.seek(0)
res = self.tmpfile.buffer.read() res = self.tmpfile.buffer.read()
self.tmpfile.seek(0) self.tmpfile.seek(0)
@ -416,17 +418,17 @@ class SysCaptureBinary(SysCaptureBase[bytes]):
return res return res
def writeorg(self, data: bytes) -> None: def writeorg(self, data: bytes) -> None:
self._assert_state("writeorg", ("started", "suspended")) self._assert_state('writeorg', ('started', 'suspended'))
self._old.flush() self._old.flush()
self._old.buffer.write(data) self._old.buffer.write(data)
self._old.buffer.flush() self._old.buffer.flush()
class SysCapture(SysCaptureBase[str]): class SysCapture(SysCaptureBase[str]):
EMPTY_BUFFER = "" EMPTY_BUFFER = ''
def snap(self) -> str: def snap(self) -> str:
self._assert_state("snap", ("started", "suspended")) self._assert_state('snap', ('started', 'suspended'))
assert isinstance(self.tmpfile, CaptureIO) assert isinstance(self.tmpfile, CaptureIO)
res = self.tmpfile.getvalue() res = self.tmpfile.getvalue()
self.tmpfile.seek(0) self.tmpfile.seek(0)
@ -434,7 +436,7 @@ class SysCapture(SysCaptureBase[str]):
return res return res
def writeorg(self, data: str) -> None: def writeorg(self, data: str) -> None:
self._assert_state("writeorg", ("started", "suspended")) self._assert_state('writeorg', ('started', 'suspended'))
self._old.write(data) self._old.write(data)
self._old.flush() self._old.flush()
@ -457,21 +459,21 @@ class FDCaptureBase(CaptureBase[AnyStr]):
# Further complications are the need to support suspend() and the # Further complications are the need to support suspend() and the
# possibility of FD reuse (e.g. the tmpfile getting the very same # possibility of FD reuse (e.g. the tmpfile getting the very same
# target FD). The following approach is robust, I believe. # target FD). The following approach is robust, I believe.
self.targetfd_invalid: Optional[int] = os.open(os.devnull, os.O_RDWR) self.targetfd_invalid: int | None = os.open(os.devnull, os.O_RDWR)
os.dup2(self.targetfd_invalid, targetfd) os.dup2(self.targetfd_invalid, targetfd)
else: else:
self.targetfd_invalid = None self.targetfd_invalid = None
self.targetfd_save = os.dup(targetfd) self.targetfd_save = os.dup(targetfd)
if targetfd == 0: if targetfd == 0:
self.tmpfile = open(os.devnull, encoding="utf-8") self.tmpfile = open(os.devnull, encoding='utf-8')
self.syscapture: CaptureBase[str] = SysCapture(targetfd) self.syscapture: CaptureBase[str] = SysCapture(targetfd)
else: else:
self.tmpfile = EncodedFile( self.tmpfile = EncodedFile(
TemporaryFile(buffering=0), TemporaryFile(buffering=0),
encoding="utf-8", encoding='utf-8',
errors="replace", errors='replace',
newline="", newline='',
write_through=True, write_through=True,
) )
if targetfd in patchsysdict: if targetfd in patchsysdict:
@ -479,10 +481,10 @@ class FDCaptureBase(CaptureBase[AnyStr]):
else: else:
self.syscapture = NoCapture(targetfd) self.syscapture = NoCapture(targetfd)
self._state = "initialized" self._state = 'initialized'
def __repr__(self) -> str: def __repr__(self) -> str:
return "<{} {} oldfd={} _state={!r} tmpfile={!r}>".format( return '<{} {} oldfd={} _state={!r} tmpfile={!r}>'.format(
self.__class__.__name__, self.__class__.__name__,
self.targetfd, self.targetfd,
self.targetfd_save, self.targetfd_save,
@ -490,25 +492,25 @@ class FDCaptureBase(CaptureBase[AnyStr]):
self.tmpfile, self.tmpfile,
) )
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None: def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert ( assert (
self._state in states self._state in states
), "cannot {} in state {!r}: expected one of {}".format( ), 'cannot {} in state {!r}: expected one of {}'.format(
op, self._state, ", ".join(states) op, self._state, ', '.join(states),
) )
def start(self) -> None: def start(self) -> None:
"""Start capturing on targetfd using memorized tmpfile.""" """Start capturing on targetfd using memorized tmpfile."""
self._assert_state("start", ("initialized",)) self._assert_state('start', ('initialized',))
os.dup2(self.tmpfile.fileno(), self.targetfd) os.dup2(self.tmpfile.fileno(), self.targetfd)
self.syscapture.start() self.syscapture.start()
self._state = "started" self._state = 'started'
def done(self) -> None: def done(self) -> None:
"""Stop capturing, restore streams, return original capture file, """Stop capturing, restore streams, return original capture file,
seeked to position zero.""" seeked to position zero."""
self._assert_state("done", ("initialized", "started", "suspended", "done")) self._assert_state('done', ('initialized', 'started', 'suspended', 'done'))
if self._state == "done": if self._state == 'done':
return return
os.dup2(self.targetfd_save, self.targetfd) os.dup2(self.targetfd_save, self.targetfd)
os.close(self.targetfd_save) os.close(self.targetfd_save)
@ -518,23 +520,23 @@ class FDCaptureBase(CaptureBase[AnyStr]):
os.close(self.targetfd_invalid) os.close(self.targetfd_invalid)
self.syscapture.done() self.syscapture.done()
self.tmpfile.close() self.tmpfile.close()
self._state = "done" self._state = 'done'
def suspend(self) -> None: def suspend(self) -> None:
self._assert_state("suspend", ("started", "suspended")) self._assert_state('suspend', ('started', 'suspended'))
if self._state == "suspended": if self._state == 'suspended':
return return
self.syscapture.suspend() self.syscapture.suspend()
os.dup2(self.targetfd_save, self.targetfd) os.dup2(self.targetfd_save, self.targetfd)
self._state = "suspended" self._state = 'suspended'
def resume(self) -> None: def resume(self) -> None:
self._assert_state("resume", ("started", "suspended")) self._assert_state('resume', ('started', 'suspended'))
if self._state == "started": if self._state == 'started':
return return
self.syscapture.resume() self.syscapture.resume()
os.dup2(self.tmpfile.fileno(), self.targetfd) os.dup2(self.tmpfile.fileno(), self.targetfd)
self._state = "started" self._state = 'started'
class FDCaptureBinary(FDCaptureBase[bytes]): class FDCaptureBinary(FDCaptureBase[bytes]):
@ -543,10 +545,10 @@ class FDCaptureBinary(FDCaptureBase[bytes]):
snap() produces `bytes`. snap() produces `bytes`.
""" """
EMPTY_BUFFER = b"" EMPTY_BUFFER = b''
def snap(self) -> bytes: def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended")) self._assert_state('snap', ('started', 'suspended'))
self.tmpfile.seek(0) self.tmpfile.seek(0)
res = self.tmpfile.buffer.read() res = self.tmpfile.buffer.read()
self.tmpfile.seek(0) self.tmpfile.seek(0)
@ -555,7 +557,7 @@ class FDCaptureBinary(FDCaptureBase[bytes]):
def writeorg(self, data: bytes) -> None: def writeorg(self, data: bytes) -> None:
"""Write to original file descriptor.""" """Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended")) self._assert_state('writeorg', ('started', 'suspended'))
os.write(self.targetfd_save, data) os.write(self.targetfd_save, data)
@ -565,10 +567,10 @@ class FDCapture(FDCaptureBase[str]):
snap() produces text. snap() produces text.
""" """
EMPTY_BUFFER = "" EMPTY_BUFFER = ''
def snap(self) -> str: def snap(self) -> str:
self._assert_state("snap", ("started", "suspended")) self._assert_state('snap', ('started', 'suspended'))
self.tmpfile.seek(0) self.tmpfile.seek(0)
res = self.tmpfile.read() res = self.tmpfile.read()
self.tmpfile.seek(0) self.tmpfile.seek(0)
@ -577,9 +579,9 @@ class FDCapture(FDCaptureBase[str]):
def writeorg(self, data: str) -> None: def writeorg(self, data: str) -> None:
"""Write to original file descriptor.""" """Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended")) self._assert_state('writeorg', ('started', 'suspended'))
# XXX use encoding of original stream # XXX use encoding of original stream
os.write(self.targetfd_save, data.encode("utf-8")) os.write(self.targetfd_save, data.encode('utf-8'))
# MultiCapture # MultiCapture
@ -598,7 +600,7 @@ if sys.version_info >= (3, 11) or TYPE_CHECKING:
else: else:
class CaptureResult( class CaptureResult(
collections.namedtuple("CaptureResult", ["out", "err"]), # noqa: PYI024 collections.namedtuple('CaptureResult', ['out', 'err']), # noqa: PYI024
Generic[AnyStr], Generic[AnyStr],
): ):
"""The result of :method:`caplog.readouterr() <pytest.CaptureFixture.readouterr>`.""" """The result of :method:`caplog.readouterr() <pytest.CaptureFixture.readouterr>`."""
@ -612,16 +614,16 @@ class MultiCapture(Generic[AnyStr]):
def __init__( def __init__(
self, self,
in_: Optional[CaptureBase[AnyStr]], in_: CaptureBase[AnyStr] | None,
out: Optional[CaptureBase[AnyStr]], out: CaptureBase[AnyStr] | None,
err: Optional[CaptureBase[AnyStr]], err: CaptureBase[AnyStr] | None,
) -> None: ) -> None:
self.in_: Optional[CaptureBase[AnyStr]] = in_ self.in_: CaptureBase[AnyStr] | None = in_
self.out: Optional[CaptureBase[AnyStr]] = out self.out: CaptureBase[AnyStr] | None = out
self.err: Optional[CaptureBase[AnyStr]] = err self.err: CaptureBase[AnyStr] | None = err
def __repr__(self) -> str: def __repr__(self) -> str:
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format( return '<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>'.format(
self.out, self.out,
self.err, self.err,
self.in_, self.in_,
@ -630,7 +632,7 @@ class MultiCapture(Generic[AnyStr]):
) )
def start_capturing(self) -> None: def start_capturing(self) -> None:
self._state = "started" self._state = 'started'
if self.in_: if self.in_:
self.in_.start() self.in_.start()
if self.out: if self.out:
@ -638,7 +640,7 @@ class MultiCapture(Generic[AnyStr]):
if self.err: if self.err:
self.err.start() self.err.start()
def pop_outerr_to_orig(self) -> Tuple[AnyStr, AnyStr]: def pop_outerr_to_orig(self) -> tuple[AnyStr, AnyStr]:
"""Pop current snapshot out/err capture and flush to orig streams.""" """Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr() out, err = self.readouterr()
if out: if out:
@ -650,7 +652,7 @@ class MultiCapture(Generic[AnyStr]):
return out, err return out, err
def suspend_capturing(self, in_: bool = False) -> None: def suspend_capturing(self, in_: bool = False) -> None:
self._state = "suspended" self._state = 'suspended'
if self.out: if self.out:
self.out.suspend() self.out.suspend()
if self.err: if self.err:
@ -660,7 +662,7 @@ class MultiCapture(Generic[AnyStr]):
self._in_suspended = True self._in_suspended = True
def resume_capturing(self) -> None: def resume_capturing(self) -> None:
self._state = "started" self._state = 'started'
if self.out: if self.out:
self.out.resume() self.out.resume()
if self.err: if self.err:
@ -672,9 +674,9 @@ class MultiCapture(Generic[AnyStr]):
def stop_capturing(self) -> None: def stop_capturing(self) -> None:
"""Stop capturing and reset capturing streams.""" """Stop capturing and reset capturing streams."""
if self._state == "stopped": if self._state == 'stopped':
raise ValueError("was already stopped") raise ValueError('was already stopped')
self._state = "stopped" self._state = 'stopped'
if self.out: if self.out:
self.out.done() self.out.done()
if self.err: if self.err:
@ -684,27 +686,27 @@ class MultiCapture(Generic[AnyStr]):
def is_started(self) -> bool: def is_started(self) -> bool:
"""Whether actively capturing -- not suspended or stopped.""" """Whether actively capturing -- not suspended or stopped."""
return self._state == "started" return self._state == 'started'
def readouterr(self) -> CaptureResult[AnyStr]: def readouterr(self) -> CaptureResult[AnyStr]:
out = self.out.snap() if self.out else "" out = self.out.snap() if self.out else ''
err = self.err.snap() if self.err else "" err = self.err.snap() if self.err else ''
# TODO: This type error is real, need to fix. # TODO: This type error is real, need to fix.
return CaptureResult(out, err) # type: ignore[arg-type] return CaptureResult(out, err) # type: ignore[arg-type]
def _get_multicapture(method: _CaptureMethod) -> MultiCapture[str]: def _get_multicapture(method: _CaptureMethod) -> MultiCapture[str]:
if method == "fd": if method == 'fd':
return MultiCapture(in_=FDCapture(0), out=FDCapture(1), err=FDCapture(2)) return MultiCapture(in_=FDCapture(0), out=FDCapture(1), err=FDCapture(2))
elif method == "sys": elif method == 'sys':
return MultiCapture(in_=SysCapture(0), out=SysCapture(1), err=SysCapture(2)) return MultiCapture(in_=SysCapture(0), out=SysCapture(1), err=SysCapture(2))
elif method == "no": elif method == 'no':
return MultiCapture(in_=None, out=None, err=None) return MultiCapture(in_=None, out=None, err=None)
elif method == "tee-sys": elif method == 'tee-sys':
return MultiCapture( return MultiCapture(
in_=None, out=SysCapture(1, tee=True), err=SysCapture(2, tee=True) in_=None, out=SysCapture(1, tee=True), err=SysCapture(2, tee=True),
) )
raise ValueError(f"unknown capturing method: {method!r}") raise ValueError(f'unknown capturing method: {method!r}')
# CaptureManager and CaptureFixture # CaptureManager and CaptureFixture
@ -731,25 +733,25 @@ class CaptureManager:
def __init__(self, method: _CaptureMethod) -> None: def __init__(self, method: _CaptureMethod) -> None:
self._method: Final = method self._method: Final = method
self._global_capturing: Optional[MultiCapture[str]] = None self._global_capturing: MultiCapture[str] | None = None
self._capture_fixture: Optional[CaptureFixture[Any]] = None self._capture_fixture: CaptureFixture[Any] | None = None
def __repr__(self) -> str: def __repr__(self) -> str:
return "<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>".format( return '<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>'.format(
self._method, self._global_capturing, self._capture_fixture self._method, self._global_capturing, self._capture_fixture,
) )
def is_capturing(self) -> Union[str, bool]: def is_capturing(self) -> str | bool:
if self.is_globally_capturing(): if self.is_globally_capturing():
return "global" return 'global'
if self._capture_fixture: if self._capture_fixture:
return "fixture %s" % self._capture_fixture.request.fixturename return 'fixture %s' % self._capture_fixture.request.fixturename
return False return False
# Global capturing control # Global capturing control
def is_globally_capturing(self) -> bool: def is_globally_capturing(self) -> bool:
return self._method != "no" return self._method != 'no'
def start_global_capturing(self) -> None: def start_global_capturing(self) -> None:
assert self._global_capturing is None assert self._global_capturing is None
@ -787,12 +789,12 @@ class CaptureManager:
# Fixture Control # Fixture Control
def set_fixture(self, capture_fixture: "CaptureFixture[Any]") -> None: def set_fixture(self, capture_fixture: CaptureFixture[Any]) -> None:
if self._capture_fixture: if self._capture_fixture:
current_fixture = self._capture_fixture.request.fixturename current_fixture = self._capture_fixture.request.fixturename
requested_fixture = capture_fixture.request.fixturename requested_fixture = capture_fixture.request.fixturename
capture_fixture.request.raiseerror( capture_fixture.request.raiseerror(
f"cannot use {requested_fixture} and {current_fixture} at the same time" f'cannot use {requested_fixture} and {current_fixture} at the same time',
) )
self._capture_fixture = capture_fixture self._capture_fixture = capture_fixture
@ -848,14 +850,14 @@ class CaptureManager:
self.suspend_global_capture(in_=False) self.suspend_global_capture(in_=False)
out, err = self.read_global_capture() out, err = self.read_global_capture()
item.add_report_section(when, "stdout", out) item.add_report_section(when, 'stdout', out)
item.add_report_section(when, "stderr", err) item.add_report_section(when, 'stderr', err)
# Hooks # Hooks
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_make_collect_report( def pytest_make_collect_report(
self, collector: Collector self, collector: Collector,
) -> Generator[None, CollectReport, CollectReport]: ) -> Generator[None, CollectReport, CollectReport]:
if isinstance(collector, File): if isinstance(collector, File):
self.resume_global_capture() self.resume_global_capture()
@ -865,26 +867,26 @@ class CaptureManager:
self.suspend_global_capture() self.suspend_global_capture()
out, err = self.read_global_capture() out, err = self.read_global_capture()
if out: if out:
rep.sections.append(("Captured stdout", out)) rep.sections.append(('Captured stdout', out))
if err: if err:
rep.sections.append(("Captured stderr", err)) rep.sections.append(('Captured stderr', err))
else: else:
rep = yield rep = yield
return rep return rep
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_setup(self, item: Item) -> Generator[None, None, None]: def pytest_runtest_setup(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("setup", item): with self.item_capture('setup', item):
return (yield) return (yield)
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_call(self, item: Item) -> Generator[None, None, None]: def pytest_runtest_call(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("call", item): with self.item_capture('call', item):
return (yield) return (yield)
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_teardown(self, item: Item) -> Generator[None, None, None]: def pytest_runtest_teardown(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("teardown", item): with self.item_capture('teardown', item):
return (yield) return (yield)
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
@ -902,15 +904,15 @@ class CaptureFixture(Generic[AnyStr]):
def __init__( def __init__(
self, self,
captureclass: Type[CaptureBase[AnyStr]], captureclass: type[CaptureBase[AnyStr]],
request: SubRequest, request: SubRequest,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self.captureclass: Type[CaptureBase[AnyStr]] = captureclass self.captureclass: type[CaptureBase[AnyStr]] = captureclass
self.request = request self.request = request
self._capture: Optional[MultiCapture[AnyStr]] = None self._capture: MultiCapture[AnyStr] | None = None
self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER
self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER
@ -968,7 +970,7 @@ class CaptureFixture(Generic[AnyStr]):
def disabled(self) -> Generator[None, None, None]: def disabled(self) -> Generator[None, None, None]:
"""Temporarily disable capturing while inside the ``with`` block.""" """Temporarily disable capturing while inside the ``with`` block."""
capmanager: CaptureManager = self.request.config.pluginmanager.getplugin( capmanager: CaptureManager = self.request.config.pluginmanager.getplugin(
"capturemanager" 'capturemanager',
) )
with capmanager.global_and_fixture_disabled(): with capmanager.global_and_fixture_disabled():
yield yield
@ -995,7 +997,7 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == "hello\n" assert captured.out == "hello\n"
""" """
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True) capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()
@ -1022,7 +1024,7 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None,
captured = capsysbinary.readouterr() captured = capsysbinary.readouterr()
assert captured.out == b"hello\n" assert captured.out == b"hello\n"
""" """
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True) capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()
@ -1049,7 +1051,7 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capfd.readouterr() captured = capfd.readouterr()
assert captured.out == "hello\n" assert captured.out == "hello\n"
""" """
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True) capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()
@ -1077,7 +1079,7 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N
assert captured.out == b"hello\n" assert captured.out == b"hello\n"
""" """
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin('capturemanager')
capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True) capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()

View file

@ -1,17 +1,16 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Python version compatibility code.""" """Python version compatibility code."""
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import enum import enum
import functools import functools
import inspect import inspect
import os
import sys
from inspect import Parameter from inspect import Parameter
from inspect import signature from inspect import signature
import os
from pathlib import Path from pathlib import Path
import sys
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Final from typing import Final
@ -57,7 +56,7 @@ def iscoroutinefunction(func: object) -> bool:
importing asyncio directly, which in turns also initializes the "logging" importing asyncio directly, which in turns also initializes the "logging"
module as a side-effect (see issue #8). module as a side-effect (see issue #8).
""" """
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False) return inspect.iscoroutinefunction(func) or getattr(func, '_is_coroutine', False)
def is_async_function(func: object) -> bool: def is_async_function(func: object) -> bool:
@ -76,33 +75,33 @@ def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
except ValueError: except ValueError:
pass pass
else: else:
return "%s:%d" % (relfn, lineno + 1) return '%s:%d' % (relfn, lineno + 1)
return "%s:%d" % (fn, lineno + 1) return '%s:%d' % (fn, lineno + 1)
def num_mock_patch_args(function) -> int: def num_mock_patch_args(function) -> int:
"""Return number of arguments used up by mock arguments (if any).""" """Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None) patchings = getattr(function, 'patchings', None)
if not patchings: if not patchings:
return 0 return 0
mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object()) mock_sentinel = getattr(sys.modules.get('mock'), 'DEFAULT', object())
ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object()) ut_mock_sentinel = getattr(sys.modules.get('unittest.mock'), 'DEFAULT', object())
return len( return len(
[ [
p p
for p in patchings for p in patchings
if not p.attribute_name if not p.attribute_name and
and (p.new is mock_sentinel or p.new is ut_mock_sentinel) (p.new is mock_sentinel or p.new is ut_mock_sentinel)
] ],
) )
def getfuncargnames( def getfuncargnames(
function: Callable[..., object], function: Callable[..., object],
*, *,
name: str = "", name: str = '',
is_method: bool = False, is_method: bool = False,
cls: type | None = None, cls: type | None = None,
) -> tuple[str, ...]: ) -> tuple[str, ...]:
@ -135,7 +134,7 @@ def getfuncargnames(
from _pytest.outcomes import fail from _pytest.outcomes import fail
fail( fail(
f"Could not determine arguments of {function!r}: {e}", f'Could not determine arguments of {function!r}: {e}',
pytrace=False, pytrace=False,
) )
@ -143,10 +142,10 @@ def getfuncargnames(
p.name p.name
for p in parameters.values() for p in parameters.values()
if ( if (
p.kind is Parameter.POSITIONAL_OR_KEYWORD p.kind is Parameter.POSITIONAL_OR_KEYWORD or
or p.kind is Parameter.KEYWORD_ONLY p.kind is Parameter.KEYWORD_ONLY
) ) and
and p.default is Parameter.empty p.default is Parameter.empty
) )
if not name: if not name:
name = function.__name__ name = function.__name__
@ -157,15 +156,15 @@ def getfuncargnames(
if is_method or ( if is_method or (
# Not using `getattr` because we don't want to resolve the staticmethod. # Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO. # Not using `cls.__dict__` because we want to check the entire MRO.
cls cls and
and not isinstance( not isinstance(
inspect.getattr_static(cls, name, default=None), staticmethod inspect.getattr_static(cls, name, default=None), staticmethod,
) )
): ):
arg_names = arg_names[1:] arg_names = arg_names[1:]
# Remove any names that will be replaced with mocks. # Remove any names that will be replaced with mocks.
if hasattr(function, "__wrapped__"): if hasattr(function, '__wrapped__'):
arg_names = arg_names[num_mock_patch_args(function) :] arg_names = arg_names[num_mock_patch_args(function):]
return arg_names return arg_names
@ -176,16 +175,16 @@ def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]:
return tuple( return tuple(
p.name p.name
for p in signature(function).parameters.values() for p in signature(function).parameters.values()
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY) if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY) and
and p.default is not Parameter.empty p.default is not Parameter.empty
) )
_non_printable_ascii_translate_table = { _non_printable_ascii_translate_table = {
i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127) i: f'\\x{i:02x}' for i in range(128) if i not in range(32, 127)
} }
_non_printable_ascii_translate_table.update( _non_printable_ascii_translate_table.update(
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"} {ord('\t'): '\\t', ord('\r'): '\\r', ord('\n'): '\\n'},
) )
@ -206,9 +205,9 @@ def ascii_escaped(val: bytes | str) -> str:
a UTF-8 string. a UTF-8 string.
""" """
if isinstance(val, bytes): if isinstance(val, bytes):
ret = val.decode("ascii", "backslashreplace") ret = val.decode('ascii', 'backslashreplace')
else: else:
ret = val.encode("unicode_escape").decode("ascii") ret = val.encode('unicode_escape').decode('ascii')
return ret.translate(_non_printable_ascii_translate_table) return ret.translate(_non_printable_ascii_translate_table)
@ -232,11 +231,11 @@ def get_real_func(obj):
# __pytest_wrapped__ is set by @pytest.fixture when wrapping the fixture function # __pytest_wrapped__ is set by @pytest.fixture when wrapping the fixture function
# to trigger a warning if it gets called directly instead of by pytest: we don't # to trigger a warning if it gets called directly instead of by pytest: we don't
# want to unwrap further than this otherwise we lose useful wrappings like @mock.patch (#3774) # want to unwrap further than this otherwise we lose useful wrappings like @mock.patch (#3774)
new_obj = getattr(obj, "__pytest_wrapped__", None) new_obj = getattr(obj, '__pytest_wrapped__', None)
if isinstance(new_obj, _PytestWrapper): if isinstance(new_obj, _PytestWrapper):
obj = new_obj.obj obj = new_obj.obj
break break
new_obj = getattr(obj, "__wrapped__", None) new_obj = getattr(obj, '__wrapped__', None)
if new_obj is None: if new_obj is None:
break break
obj = new_obj obj = new_obj
@ -244,7 +243,7 @@ def get_real_func(obj):
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
raise ValueError( raise ValueError(
f"could not find real function of {saferepr(start_obj)}\nstopped at {saferepr(obj)}" f'could not find real function of {saferepr(start_obj)}\nstopped at {saferepr(obj)}',
) )
if isinstance(obj, functools.partial): if isinstance(obj, functools.partial):
obj = obj.func obj = obj.func
@ -256,11 +255,11 @@ def get_real_method(obj, holder):
``obj``, while at the same time returning a bound method to ``holder`` if ``obj``, while at the same time returning a bound method to ``holder`` if
the original object was a bound method.""" the original object was a bound method."""
try: try:
is_method = hasattr(obj, "__func__") is_method = hasattr(obj, '__func__')
obj = get_real_func(obj) obj = get_real_func(obj)
except Exception: # pragma: no cover except Exception: # pragma: no cover
return obj return obj
if is_method and hasattr(obj, "__get__") and callable(obj.__get__): if is_method and hasattr(obj, '__get__') and callable(obj.__get__):
obj = obj.__get__(holder) obj = obj.__get__(holder)
return obj return obj
@ -306,7 +305,7 @@ def get_user_id() -> int | None:
# mypy follows the version and platform checking expectation of PEP 484: # mypy follows the version and platform checking expectation of PEP 484:
# https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=platform#python-version-and-system-platform-checks # https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=platform#python-version-and-system-platform-checks
# Containment checks are too complex for mypy v1.5.0 and cause failure. # Containment checks are too complex for mypy v1.5.0 and cause failure.
if sys.platform == "win32" or sys.platform == "emscripten": # noqa: PLR1714 if sys.platform == 'win32' or sys.platform == 'emscripten': # noqa: PLR1714
# win32 does not have a getuid() function. # win32 does not have a getuid() function.
# Emscripten has a return 0 stub. # Emscripten has a return 0 stub.
return None return None
@ -350,4 +349,4 @@ def get_user_id() -> int | None:
# #
# This also work for Enums (if you use `is` to compare) and Literals. # This also work for Enums (if you use `is` to compare) and Literals.
def assert_never(value: NoReturn) -> NoReturn: def assert_never(value: NoReturn) -> NoReturn:
assert False, f"Unhandled value: {value} ({type(value).__name__})" assert False, f'Unhandled value: {value} ({type(value).__name__})'

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,10 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import argparse import argparse
from gettext import gettext
import os import os
import sys import sys
from gettext import gettext
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast from typing import cast
@ -22,12 +24,12 @@ from _pytest.config.exceptions import UsageError
from _pytest.deprecated import check_ispytest from _pytest.deprecated import check_ispytest
FILE_OR_DIR = "file_or_dir" FILE_OR_DIR = 'file_or_dir'
class NotSet: class NotSet:
def __repr__(self) -> str: def __repr__(self) -> str:
return "<notset>" return '<notset>'
NOT_SET = NotSet() NOT_SET = NotSet()
@ -41,32 +43,32 @@ class Parser:
there's an error processing the command line arguments. there's an error processing the command line arguments.
""" """
prog: Optional[str] = None prog: str | None = None
def __init__( def __init__(
self, self,
usage: Optional[str] = None, usage: str | None = None,
processopt: Optional[Callable[["Argument"], None]] = None, processopt: Callable[[Argument], None] | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True) self._anonymous = OptionGroup('Custom options', parser=self, _ispytest=True)
self._groups: List[OptionGroup] = [] self._groups: list[OptionGroup] = []
self._processopt = processopt self._processopt = processopt
self._usage = usage self._usage = usage
self._inidict: Dict[str, Tuple[str, Optional[str], Any]] = {} self._inidict: dict[str, tuple[str, str | None, Any]] = {}
self._ininames: List[str] = [] self._ininames: list[str] = []
self.extra_info: Dict[str, Any] = {} self.extra_info: dict[str, Any] = {}
def processoption(self, option: "Argument") -> None: def processoption(self, option: Argument) -> None:
if self._processopt: if self._processopt:
if option.dest: if option.dest:
self._processopt(option) self._processopt(option)
def getgroup( def getgroup(
self, name: str, description: str = "", after: Optional[str] = None self, name: str, description: str = '', after: str | None = None,
) -> "OptionGroup": ) -> OptionGroup:
"""Get (or create) a named option Group. """Get (or create) a named option Group.
:param name: Name of the option group. :param name: Name of the option group.
@ -108,8 +110,8 @@ class Parser:
def parse( def parse(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> argparse.Namespace: ) -> argparse.Namespace:
from _pytest._argcomplete import try_argcomplete from _pytest._argcomplete import try_argcomplete
@ -118,7 +120,7 @@ class Parser:
strargs = [os.fspath(x) for x in args] strargs = [os.fspath(x) for x in args]
return self.optparser.parse_args(strargs, namespace=namespace) return self.optparser.parse_args(strargs, namespace=namespace)
def _getparser(self) -> "MyOptionParser": def _getparser(self) -> MyOptionParser:
from _pytest._argcomplete import filescompleter from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog) optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
@ -131,7 +133,7 @@ class Parser:
n = option.names() n = option.names()
a = option.attrs() a = option.attrs()
arggroup.add_argument(*n, **a) arggroup.add_argument(*n, **a)
file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs="*") file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs='*')
# bash like autocompletion for dirs (appending '/') # bash like autocompletion for dirs (appending '/')
# Type ignored because typeshed doesn't know about argcomplete. # Type ignored because typeshed doesn't know about argcomplete.
file_or_dir_arg.completer = filescompleter # type: ignore file_or_dir_arg.completer = filescompleter # type: ignore
@ -139,10 +141,10 @@ class Parser:
def parse_setoption( def parse_setoption(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
option: argparse.Namespace, option: argparse.Namespace,
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> List[str]: ) -> list[str]:
parsedoption = self.parse(args, namespace=namespace) parsedoption = self.parse(args, namespace=namespace)
for name, value in parsedoption.__dict__.items(): for name, value in parsedoption.__dict__.items():
setattr(option, name, value) setattr(option, name, value)
@ -150,8 +152,8 @@ class Parser:
def parse_known_args( def parse_known_args(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> argparse.Namespace: ) -> argparse.Namespace:
"""Parse the known arguments at this point. """Parse the known arguments at this point.
@ -161,9 +163,9 @@ class Parser:
def parse_known_and_unknown_args( def parse_known_and_unknown_args(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> Tuple[argparse.Namespace, List[str]]: ) -> tuple[argparse.Namespace, list[str]]:
"""Parse the known arguments at this point, and also return the """Parse the known arguments at this point, and also return the
remaining unknown arguments. remaining unknown arguments.
@ -179,9 +181,9 @@ class Parser:
self, self,
name: str, name: str,
help: str, help: str,
type: Optional[ type: None | (
Literal["string", "paths", "pathlist", "args", "linelist", "bool"] Literal['string', 'paths', 'pathlist', 'args', 'linelist', 'bool']
] = None, ) = None,
default: Any = NOT_SET, default: Any = NOT_SET,
) -> None: ) -> None:
"""Register an ini-file option. """Register an ini-file option.
@ -215,7 +217,7 @@ class Parser:
The value of ini-variables can be retrieved via a call to The value of ini-variables can be retrieved via a call to
:py:func:`config.getini(name) <pytest.Config.getini>`. :py:func:`config.getini(name) <pytest.Config.getini>`.
""" """
assert type in (None, "string", "paths", "pathlist", "args", "linelist", "bool") assert type in (None, 'string', 'paths', 'pathlist', 'args', 'linelist', 'bool')
if default is NOT_SET: if default is NOT_SET:
default = get_ini_default_for_type(type) default = get_ini_default_for_type(type)
@ -224,33 +226,33 @@ class Parser:
def get_ini_default_for_type( def get_ini_default_for_type(
type: Optional[Literal["string", "paths", "pathlist", "args", "linelist", "bool"]], type: Literal['string', 'paths', 'pathlist', 'args', 'linelist', 'bool'] | None,
) -> Any: ) -> Any:
""" """
Used by addini to get the default value for a given ini-option type, when Used by addini to get the default value for a given ini-option type, when
default is not supplied. default is not supplied.
""" """
if type is None: if type is None:
return "" return ''
elif type in ("paths", "pathlist", "args", "linelist"): elif type in ('paths', 'pathlist', 'args', 'linelist'):
return [] return []
elif type == "bool": elif type == 'bool':
return False return False
else: else:
return "" return ''
class ArgumentError(Exception): class ArgumentError(Exception):
"""Raised if an Argument instance is created with invalid or """Raised if an Argument instance is created with invalid or
inconsistent arguments.""" inconsistent arguments."""
def __init__(self, msg: str, option: Union["Argument", str]) -> None: def __init__(self, msg: str, option: Argument | str) -> None:
self.msg = msg self.msg = msg
self.option_id = str(option) self.option_id = str(option)
def __str__(self) -> str: def __str__(self) -> str:
if self.option_id: if self.option_id:
return f"option {self.option_id}: {self.msg}" return f'option {self.option_id}: {self.msg}'
else: else:
return self.msg return self.msg
@ -267,36 +269,36 @@ class Argument:
def __init__(self, *names: str, **attrs: Any) -> None: def __init__(self, *names: str, **attrs: Any) -> None:
"""Store params in private vars for use in add_argument.""" """Store params in private vars for use in add_argument."""
self._attrs = attrs self._attrs = attrs
self._short_opts: List[str] = [] self._short_opts: list[str] = []
self._long_opts: List[str] = [] self._long_opts: list[str] = []
try: try:
self.type = attrs["type"] self.type = attrs['type']
except KeyError: except KeyError:
pass pass
try: try:
# Attribute existence is tested in Config._processopt. # Attribute existence is tested in Config._processopt.
self.default = attrs["default"] self.default = attrs['default']
except KeyError: except KeyError:
pass pass
self._set_opt_strings(names) self._set_opt_strings(names)
dest: Optional[str] = attrs.get("dest") dest: str | None = attrs.get('dest')
if dest: if dest:
self.dest = dest self.dest = dest
elif self._long_opts: elif self._long_opts:
self.dest = self._long_opts[0][2:].replace("-", "_") self.dest = self._long_opts[0][2:].replace('-', '_')
else: else:
try: try:
self.dest = self._short_opts[0][1:] self.dest = self._short_opts[0][1:]
except IndexError as e: except IndexError as e:
self.dest = "???" # Needed for the error repr. self.dest = '???' # Needed for the error repr.
raise ArgumentError("need a long or short option", self) from e raise ArgumentError('need a long or short option', self) from e
def names(self) -> List[str]: def names(self) -> list[str]:
return self._short_opts + self._long_opts return self._short_opts + self._long_opts
def attrs(self) -> Mapping[str, Any]: def attrs(self) -> Mapping[str, Any]:
# Update any attributes set by processopt. # Update any attributes set by processopt.
attrs = "default dest help".split() attrs = 'default dest help'.split()
attrs.append(self.dest) attrs.append(self.dest)
for attr in attrs: for attr in attrs:
try: try:
@ -313,39 +315,39 @@ class Argument:
for opt in opts: for opt in opts:
if len(opt) < 2: if len(opt) < 2:
raise ArgumentError( raise ArgumentError(
"invalid option string %r: " 'invalid option string %r: '
"must be at least two characters long" % opt, 'must be at least two characters long' % opt,
self, self,
) )
elif len(opt) == 2: elif len(opt) == 2:
if not (opt[0] == "-" and opt[1] != "-"): if not (opt[0] == '-' and opt[1] != '-'):
raise ArgumentError( raise ArgumentError(
"invalid short option string %r: " 'invalid short option string %r: '
"must be of the form -x, (x any non-dash char)" % opt, 'must be of the form -x, (x any non-dash char)' % opt,
self, self,
) )
self._short_opts.append(opt) self._short_opts.append(opt)
else: else:
if not (opt[0:2] == "--" and opt[2] != "-"): if not (opt[0:2] == '--' and opt[2] != '-'):
raise ArgumentError( raise ArgumentError(
"invalid long option string %r: " 'invalid long option string %r: '
"must start with --, followed by non-dash" % opt, 'must start with --, followed by non-dash' % opt,
self, self,
) )
self._long_opts.append(opt) self._long_opts.append(opt)
def __repr__(self) -> str: def __repr__(self) -> str:
args: List[str] = [] args: list[str] = []
if self._short_opts: if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)] args += ['_short_opts: ' + repr(self._short_opts)]
if self._long_opts: if self._long_opts:
args += ["_long_opts: " + repr(self._long_opts)] args += ['_long_opts: ' + repr(self._long_opts)]
args += ["dest: " + repr(self.dest)] args += ['dest: ' + repr(self.dest)]
if hasattr(self, "type"): if hasattr(self, 'type'):
args += ["type: " + repr(self.type)] args += ['type: ' + repr(self.type)]
if hasattr(self, "default"): if hasattr(self, 'default'):
args += ["default: " + repr(self.default)] args += ['default: ' + repr(self.default)]
return "Argument({})".format(", ".join(args)) return 'Argument({})'.format(', '.join(args))
class OptionGroup: class OptionGroup:
@ -354,15 +356,15 @@ class OptionGroup:
def __init__( def __init__(
self, self,
name: str, name: str,
description: str = "", description: str = '',
parser: Optional[Parser] = None, parser: Parser | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self.name = name self.name = name
self.description = description self.description = description
self.options: List[Argument] = [] self.options: list[Argument] = []
self.parser = parser self.parser = parser
def addoption(self, *opts: str, **attrs: Any) -> None: def addoption(self, *opts: str, **attrs: Any) -> None:
@ -383,7 +385,7 @@ class OptionGroup:
name for opt in self.options for name in opt.names() name for opt in self.options for name in opt.names()
) )
if conflict: if conflict:
raise ValueError("option names %s already added" % conflict) raise ValueError('option names %s already added' % conflict)
option = Argument(*opts, **attrs) option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=False) self._addoption_instance(option, shortupper=False)
@ -391,11 +393,11 @@ class OptionGroup:
option = Argument(*opts, **attrs) option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=True) self._addoption_instance(option, shortupper=True)
def _addoption_instance(self, option: "Argument", shortupper: bool = False) -> None: def _addoption_instance(self, option: Argument, shortupper: bool = False) -> None:
if not shortupper: if not shortupper:
for opt in option._short_opts: for opt in option._short_opts:
if opt[0] == "-" and opt[1].islower(): if opt[0] == '-' and opt[1].islower():
raise ValueError("lowercase shortoptions reserved") raise ValueError('lowercase shortoptions reserved')
if self.parser: if self.parser:
self.parser.processoption(option) self.parser.processoption(option)
self.options.append(option) self.options.append(option)
@ -405,8 +407,8 @@ class MyOptionParser(argparse.ArgumentParser):
def __init__( def __init__(
self, self,
parser: Parser, parser: Parser,
extra_info: Optional[Dict[str, Any]] = None, extra_info: dict[str, Any] | None = None,
prog: Optional[str] = None, prog: str | None = None,
) -> None: ) -> None:
self._parser = parser self._parser = parser
super().__init__( super().__init__(
@ -422,29 +424,29 @@ class MyOptionParser(argparse.ArgumentParser):
def error(self, message: str) -> NoReturn: def error(self, message: str) -> NoReturn:
"""Transform argparse error message into UsageError.""" """Transform argparse error message into UsageError."""
msg = f"{self.prog}: error: {message}" msg = f'{self.prog}: error: {message}'
if hasattr(self._parser, "_config_source_hint"): if hasattr(self._parser, '_config_source_hint'):
# Type ignored because the attribute is set dynamically. # Type ignored because the attribute is set dynamically.
msg = f"{msg} ({self._parser._config_source_hint})" # type: ignore msg = f'{msg} ({self._parser._config_source_hint})' # type: ignore
raise UsageError(self.format_usage() + msg) raise UsageError(self.format_usage() + msg)
# Type ignored because typeshed has a very complex type in the superclass. # Type ignored because typeshed has a very complex type in the superclass.
def parse_args( # type: ignore def parse_args( # type: ignore
self, self,
args: Optional[Sequence[str]] = None, args: Sequence[str] | None = None,
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> argparse.Namespace: ) -> argparse.Namespace:
"""Allow splitting of positional arguments.""" """Allow splitting of positional arguments."""
parsed, unrecognized = self.parse_known_args(args, namespace) parsed, unrecognized = self.parse_known_args(args, namespace)
if unrecognized: if unrecognized:
for arg in unrecognized: for arg in unrecognized:
if arg and arg[0] == "-": if arg and arg[0] == '-':
lines = ["unrecognized arguments: %s" % (" ".join(unrecognized))] lines = ['unrecognized arguments: %s' % (' '.join(unrecognized))]
for k, v in sorted(self.extra_info.items()): for k, v in sorted(self.extra_info.items()):
lines.append(f" {k}: {v}") lines.append(f' {k}: {v}')
self.error("\n".join(lines)) self.error('\n'.join(lines))
getattr(parsed, FILE_OR_DIR).extend(unrecognized) getattr(parsed, FILE_OR_DIR).extend(unrecognized)
return parsed return parsed
@ -452,8 +454,8 @@ class MyOptionParser(argparse.ArgumentParser):
# Backport of https://github.com/python/cpython/pull/14316 so we can # Backport of https://github.com/python/cpython/pull/14316 so we can
# disable long --argument abbreviations without breaking short flags. # disable long --argument abbreviations without breaking short flags.
def _parse_optional( def _parse_optional(
self, arg_string: str self, arg_string: str,
) -> Optional[Tuple[Optional[argparse.Action], str, Optional[str]]]: ) -> tuple[argparse.Action | None, str, str | None] | None:
if not arg_string: if not arg_string:
return None return None
if arg_string[0] not in self.prefix_chars: if arg_string[0] not in self.prefix_chars:
@ -463,26 +465,26 @@ class MyOptionParser(argparse.ArgumentParser):
return action, arg_string, None return action, arg_string, None
if len(arg_string) == 1: if len(arg_string) == 1:
return None return None
if "=" in arg_string: if '=' in arg_string:
option_string, explicit_arg = arg_string.split("=", 1) option_string, explicit_arg = arg_string.split('=', 1)
if option_string in self._option_string_actions: if option_string in self._option_string_actions:
action = self._option_string_actions[option_string] action = self._option_string_actions[option_string]
return action, option_string, explicit_arg return action, option_string, explicit_arg
if self.allow_abbrev or not arg_string.startswith("--"): if self.allow_abbrev or not arg_string.startswith('--'):
option_tuples = self._get_option_tuples(arg_string) option_tuples = self._get_option_tuples(arg_string)
if len(option_tuples) > 1: if len(option_tuples) > 1:
msg = gettext( msg = gettext(
"ambiguous option: %(option)s could match %(matches)s" 'ambiguous option: %(option)s could match %(matches)s',
) )
options = ", ".join(option for _, option, _ in option_tuples) options = ', '.join(option for _, option, _ in option_tuples)
self.error(msg % {"option": arg_string, "matches": options}) self.error(msg % {'option': arg_string, 'matches': options})
elif len(option_tuples) == 1: elif len(option_tuples) == 1:
(option_tuple,) = option_tuples (option_tuple,) = option_tuples
return option_tuple return option_tuple
if self._negative_number_matcher.match(arg_string): if self._negative_number_matcher.match(arg_string):
if not self._has_negative_number_optionals: if not self._has_negative_number_optionals:
return None return None
if " " in arg_string: if ' ' in arg_string:
return None return None
return None, arg_string, None return None, arg_string, None
@ -497,45 +499,45 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
# Use more accurate terminal width. # Use more accurate terminal width.
if "width" not in kwargs: if 'width' not in kwargs:
kwargs["width"] = _pytest._io.get_terminal_width() kwargs['width'] = _pytest._io.get_terminal_width()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _format_action_invocation(self, action: argparse.Action) -> str: def _format_action_invocation(self, action: argparse.Action) -> str:
orgstr = super()._format_action_invocation(action) orgstr = super()._format_action_invocation(action)
if orgstr and orgstr[0] != "-": # only optional arguments if orgstr and orgstr[0] != '-': # only optional arguments
return orgstr return orgstr
res: Optional[str] = getattr(action, "_formatted_action_invocation", None) res: str | None = getattr(action, '_formatted_action_invocation', None)
if res: if res:
return res return res
options = orgstr.split(", ") options = orgstr.split(', ')
if len(options) == 2 and (len(options[0]) == 2 or len(options[1]) == 2): if len(options) == 2 and (len(options[0]) == 2 or len(options[1]) == 2):
# a shortcut for '-h, --help' or '--abc', '-a' # a shortcut for '-h, --help' or '--abc', '-a'
action._formatted_action_invocation = orgstr # type: ignore action._formatted_action_invocation = orgstr # type: ignore
return orgstr return orgstr
return_list = [] return_list = []
short_long: Dict[str, str] = {} short_long: dict[str, str] = {}
for option in options: for option in options:
if len(option) == 2 or option[2] == " ": if len(option) == 2 or option[2] == ' ':
continue continue
if not option.startswith("--"): if not option.startswith('--'):
raise ArgumentError( raise ArgumentError(
'long optional argument without "--": [%s]' % (option), option 'long optional argument without "--": [%s]' % (option), option,
) )
xxoption = option[2:] xxoption = option[2:]
shortened = xxoption.replace("-", "") shortened = xxoption.replace('-', '')
if shortened not in short_long or len(short_long[shortened]) < len( if shortened not in short_long or len(short_long[shortened]) < len(
xxoption xxoption,
): ):
short_long[shortened] = xxoption short_long[shortened] = xxoption
# now short_long has been filled out to the longest with dashes # now short_long has been filled out to the longest with dashes
# **and** we keep the right option ordering from add_argument # **and** we keep the right option ordering from add_argument
for option in options: for option in options:
if len(option) == 2 or option[2] == " ": if len(option) == 2 or option[2] == ' ':
return_list.append(option) return_list.append(option)
if option[2:] == short_long.get(option.replace("-", "")): if option[2:] == short_long.get(option.replace('-', '')):
return_list.append(option.replace(" ", "=", 1)) return_list.append(option.replace(' ', '=', 1))
formatted_action_invocation = ", ".join(return_list) formatted_action_invocation = ', '.join(return_list)
action._formatted_action_invocation = formatted_action_invocation # type: ignore action._formatted_action_invocation = formatted_action_invocation # type: ignore
return formatted_action_invocation return formatted_action_invocation

View file

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import functools import functools
import warnings
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Mapping from typing import Mapping
import warnings
import pluggy import pluggy
@ -15,19 +15,19 @@ from ..deprecated import HOOK_LEGACY_PATH_ARG
# hookname: (Path, LEGACY_PATH) # hookname: (Path, LEGACY_PATH)
imply_paths_hooks: Mapping[str, tuple[str, str]] = { imply_paths_hooks: Mapping[str, tuple[str, str]] = {
"pytest_ignore_collect": ("collection_path", "path"), 'pytest_ignore_collect': ('collection_path', 'path'),
"pytest_collect_file": ("file_path", "path"), 'pytest_collect_file': ('file_path', 'path'),
"pytest_pycollect_makemodule": ("module_path", "path"), 'pytest_pycollect_makemodule': ('module_path', 'path'),
"pytest_report_header": ("start_path", "startdir"), 'pytest_report_header': ('start_path', 'startdir'),
"pytest_report_collectionfinish": ("start_path", "startdir"), 'pytest_report_collectionfinish': ('start_path', 'startdir'),
} }
def _check_path(path: Path, fspath: LEGACY_PATH) -> None: def _check_path(path: Path, fspath: LEGACY_PATH) -> None:
if Path(fspath) != path: if Path(fspath) != path:
raise ValueError( raise ValueError(
f"Path({fspath!r}) != {path!r}\n" f'Path({fspath!r}) != {path!r}\n'
"if both path and fspath are given they need to be equal" 'if both path and fspath are given they need to be equal',
) )
@ -61,7 +61,7 @@ class PathAwareHookProxy:
if fspath_value is not None: if fspath_value is not None:
warnings.warn( warnings.warn(
HOOK_LEGACY_PATH_ARG.format( HOOK_LEGACY_PATH_ARG.format(
pylib_path_arg=fspath_var, pathlib_path_arg=path_var pylib_path_arg=fspath_var, pathlib_path_arg=path_var,
), ),
stacklevel=2, stacklevel=2,
) )

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import final from typing import final

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import os import os
from pathlib import Path
import sys import sys
from pathlib import Path
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
from typing import List from typing import List
@ -10,13 +12,13 @@ from typing import Tuple
from typing import Union from typing import Union
import iniconfig import iniconfig
from .exceptions import UsageError
from _pytest.outcomes import fail from _pytest.outcomes import fail
from _pytest.pathlib import absolutepath from _pytest.pathlib import absolutepath
from _pytest.pathlib import commonpath from _pytest.pathlib import commonpath
from _pytest.pathlib import safe_exists from _pytest.pathlib import safe_exists
from .exceptions import UsageError
def _parse_ini_config(path: Path) -> iniconfig.IniConfig: def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
"""Parse the given generic '.ini' file using legacy IniConfig parser, returning """Parse the given generic '.ini' file using legacy IniConfig parser, returning
@ -32,52 +34,52 @@ def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
def load_config_dict_from_file( def load_config_dict_from_file(
filepath: Path, filepath: Path,
) -> Optional[Dict[str, Union[str, List[str]]]]: ) -> dict[str, str | list[str]] | None:
"""Load pytest configuration from the given file path, if supported. """Load pytest configuration from the given file path, if supported.
Return None if the file does not contain valid pytest configuration. Return None if the file does not contain valid pytest configuration.
""" """
# Configuration from ini files are obtained from the [pytest] section, if present. # Configuration from ini files are obtained from the [pytest] section, if present.
if filepath.suffix == ".ini": if filepath.suffix == '.ini':
iniconfig = _parse_ini_config(filepath) iniconfig = _parse_ini_config(filepath)
if "pytest" in iniconfig: if 'pytest' in iniconfig:
return dict(iniconfig["pytest"].items()) return dict(iniconfig['pytest'].items())
else: else:
# "pytest.ini" files are always the source of configuration, even if empty. # "pytest.ini" files are always the source of configuration, even if empty.
if filepath.name == "pytest.ini": if filepath.name == 'pytest.ini':
return {} return {}
# '.cfg' files are considered if they contain a "[tool:pytest]" section. # '.cfg' files are considered if they contain a "[tool:pytest]" section.
elif filepath.suffix == ".cfg": elif filepath.suffix == '.cfg':
iniconfig = _parse_ini_config(filepath) iniconfig = _parse_ini_config(filepath)
if "tool:pytest" in iniconfig.sections: if 'tool:pytest' in iniconfig.sections:
return dict(iniconfig["tool:pytest"].items()) return dict(iniconfig['tool:pytest'].items())
elif "pytest" in iniconfig.sections: elif 'pytest' in iniconfig.sections:
# If a setup.cfg contains a "[pytest]" section, we raise a failure to indicate users that # If a setup.cfg contains a "[pytest]" section, we raise a failure to indicate users that
# plain "[pytest]" sections in setup.cfg files is no longer supported (#3086). # plain "[pytest]" sections in setup.cfg files is no longer supported (#3086).
fail(CFG_PYTEST_SECTION.format(filename="setup.cfg"), pytrace=False) fail(CFG_PYTEST_SECTION.format(filename='setup.cfg'), pytrace=False)
# '.toml' files are considered if they contain a [tool.pytest.ini_options] table. # '.toml' files are considered if they contain a [tool.pytest.ini_options] table.
elif filepath.suffix == ".toml": elif filepath.suffix == '.toml':
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
import tomllib import tomllib
else: else:
import tomli as tomllib import tomli as tomllib
toml_text = filepath.read_text(encoding="utf-8") toml_text = filepath.read_text(encoding='utf-8')
try: try:
config = tomllib.loads(toml_text) config = tomllib.loads(toml_text)
except tomllib.TOMLDecodeError as exc: except tomllib.TOMLDecodeError as exc:
raise UsageError(f"{filepath}: {exc}") from exc raise UsageError(f'{filepath}: {exc}') from exc
result = config.get("tool", {}).get("pytest", {}).get("ini_options", None) result = config.get('tool', {}).get('pytest', {}).get('ini_options', None)
if result is not None: if result is not None:
# TOML supports richer data types than ini files (strings, arrays, floats, ints, etc), # TOML supports richer data types than ini files (strings, arrays, floats, ints, etc),
# however we need to convert all scalar values to str for compatibility with the rest # however we need to convert all scalar values to str for compatibility with the rest
# of the configuration system, which expects strings only. # of the configuration system, which expects strings only.
def make_scalar(v: object) -> Union[str, List[str]]: def make_scalar(v: object) -> str | list[str]:
return v if isinstance(v, list) else str(v) return v if isinstance(v, list) else str(v)
return {k: make_scalar(v) for k, v in result.items()} return {k: make_scalar(v) for k, v in result.items()}
@ -88,27 +90,27 @@ def load_config_dict_from_file(
def locate_config( def locate_config(
invocation_dir: Path, invocation_dir: Path,
args: Iterable[Path], args: Iterable[Path],
) -> Tuple[Optional[Path], Optional[Path], Dict[str, Union[str, List[str]]]]: ) -> tuple[Path | None, Path | None, dict[str, str | list[str]]]:
"""Search in the list of arguments for a valid ini-file for pytest, """Search in the list of arguments for a valid ini-file for pytest,
and return a tuple of (rootdir, inifile, cfg-dict).""" and return a tuple of (rootdir, inifile, cfg-dict)."""
config_names = [ config_names = [
"pytest.ini", 'pytest.ini',
".pytest.ini", '.pytest.ini',
"pyproject.toml", 'pyproject.toml',
"tox.ini", 'tox.ini',
"setup.cfg", 'setup.cfg',
] ]
args = [x for x in args if not str(x).startswith("-")] args = [x for x in args if not str(x).startswith('-')]
if not args: if not args:
args = [invocation_dir] args = [invocation_dir]
found_pyproject_toml: Optional[Path] = None found_pyproject_toml: Path | None = None
for arg in args: for arg in args:
argpath = absolutepath(arg) argpath = absolutepath(arg)
for base in (argpath, *argpath.parents): for base in (argpath, *argpath.parents):
for config_name in config_names: for config_name in config_names:
p = base / config_name p = base / config_name
if p.is_file(): if p.is_file():
if p.name == "pyproject.toml" and found_pyproject_toml is None: if p.name == 'pyproject.toml' and found_pyproject_toml is None:
found_pyproject_toml = p found_pyproject_toml = p
ini_config = load_config_dict_from_file(p) ini_config = load_config_dict_from_file(p)
if ini_config is not None: if ini_config is not None:
@ -122,7 +124,7 @@ def get_common_ancestor(
invocation_dir: Path, invocation_dir: Path,
paths: Iterable[Path], paths: Iterable[Path],
) -> Path: ) -> Path:
common_ancestor: Optional[Path] = None common_ancestor: Path | None = None
for path in paths: for path in paths:
if not path.exists(): if not path.exists():
continue continue
@ -144,12 +146,12 @@ def get_common_ancestor(
return common_ancestor return common_ancestor
def get_dirs_from_args(args: Iterable[str]) -> List[Path]: def get_dirs_from_args(args: Iterable[str]) -> list[Path]:
def is_option(x: str) -> bool: def is_option(x: str) -> bool:
return x.startswith("-") return x.startswith('-')
def get_file_part_from_node_id(x: str) -> str: def get_file_part_from_node_id(x: str) -> str:
return x.split("::")[0] return x.split('::')[0]
def get_dir_from_path(path: Path) -> Path: def get_dir_from_path(path: Path) -> Path:
if path.is_dir(): if path.is_dir():
@ -166,16 +168,16 @@ def get_dirs_from_args(args: Iterable[str]) -> List[Path]:
return [get_dir_from_path(path) for path in possible_paths if safe_exists(path)] return [get_dir_from_path(path) for path in possible_paths if safe_exists(path)]
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead." CFG_PYTEST_SECTION = '[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead.'
def determine_setup( def determine_setup(
*, *,
inifile: Optional[str], inifile: str | None,
args: Sequence[str], args: Sequence[str],
rootdir_cmd_arg: Optional[str], rootdir_cmd_arg: str | None,
invocation_dir: Path, invocation_dir: Path,
) -> Tuple[Path, Optional[Path], Dict[str, Union[str, List[str]]]]: ) -> tuple[Path, Path | None, dict[str, str | list[str]]]:
"""Determine the rootdir, inifile and ini configuration values from the """Determine the rootdir, inifile and ini configuration values from the
command line arguments. command line arguments.
@ -192,7 +194,7 @@ def determine_setup(
dirs = get_dirs_from_args(args) dirs = get_dirs_from_args(args)
if inifile: if inifile:
inipath_ = absolutepath(inifile) inipath_ = absolutepath(inifile)
inipath: Optional[Path] = inipath_ inipath: Path | None = inipath_
inicfg = load_config_dict_from_file(inipath_) or {} inicfg = load_config_dict_from_file(inipath_) or {}
if rootdir_cmd_arg is None: if rootdir_cmd_arg is None:
rootdir = inipath_.parent rootdir = inipath_.parent
@ -201,7 +203,7 @@ def determine_setup(
rootdir, inipath, inicfg = locate_config(invocation_dir, [ancestor]) rootdir, inipath, inicfg = locate_config(invocation_dir, [ancestor])
if rootdir is None and rootdir_cmd_arg is None: if rootdir is None and rootdir_cmd_arg is None:
for possible_rootdir in (ancestor, *ancestor.parents): for possible_rootdir in (ancestor, *ancestor.parents):
if (possible_rootdir / "setup.py").is_file(): if (possible_rootdir / 'setup.py').is_file():
rootdir = possible_rootdir rootdir = possible_rootdir
break break
else: else:
@ -209,7 +211,7 @@ def determine_setup(
rootdir, inipath, inicfg = locate_config(invocation_dir, dirs) rootdir, inipath, inicfg = locate_config(invocation_dir, dirs)
if rootdir is None: if rootdir is None:
rootdir = get_common_ancestor( rootdir = get_common_ancestor(
invocation_dir, [invocation_dir, ancestor] invocation_dir, [invocation_dir, ancestor],
) )
if is_fs_root(rootdir): if is_fs_root(rootdir):
rootdir = ancestor rootdir = ancestor
@ -217,7 +219,7 @@ def determine_setup(
rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg)) rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg))
if not rootdir.is_dir(): if not rootdir.is_dir():
raise UsageError( raise UsageError(
f"Directory '{rootdir}' not found. Check your '--rootdir' option." f"Directory '{rootdir}' not found. Check your '--rootdir' option.",
) )
assert rootdir is not None assert rootdir is not None
return rootdir, inipath, inicfg or {} return rootdir, inipath, inicfg or {}

View file

@ -1,9 +1,12 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Interactive debugging with PDB, the Python Debugger.""" """Interactive debugging with PDB, the Python Debugger."""
from __future__ import annotations
import argparse import argparse
import functools import functools
import sys import sys
import types import types
import unittest
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Generator from typing import Generator
@ -13,7 +16,6 @@ from typing import Tuple
from typing import Type from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
import unittest
from _pytest import outcomes from _pytest import outcomes
from _pytest._code import ExceptionInfo from _pytest._code import ExceptionInfo
@ -32,51 +34,51 @@ if TYPE_CHECKING:
from _pytest.runner import CallInfo from _pytest.runner import CallInfo
def _validate_usepdb_cls(value: str) -> Tuple[str, str]: def _validate_usepdb_cls(value: str) -> tuple[str, str]:
"""Validate syntax of --pdbcls option.""" """Validate syntax of --pdbcls option."""
try: try:
modname, classname = value.split(":") modname, classname = value.split(':')
except ValueError as e: except ValueError as e:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
f"{value!r} is not in the format 'modname:classname'" f"{value!r} is not in the format 'modname:classname'",
) from e ) from e
return (modname, classname) return (modname, classname)
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general") group = parser.getgroup('general')
group._addoption( group._addoption(
"--pdb", '--pdb',
dest="usepdb", dest='usepdb',
action="store_true", action='store_true',
help="Start the interactive Python debugger on errors or KeyboardInterrupt", help='Start the interactive Python debugger on errors or KeyboardInterrupt',
) )
group._addoption( group._addoption(
"--pdbcls", '--pdbcls',
dest="usepdb_cls", dest='usepdb_cls',
metavar="modulename:classname", metavar='modulename:classname',
type=_validate_usepdb_cls, type=_validate_usepdb_cls,
help="Specify a custom interactive Python debugger for use with --pdb." help='Specify a custom interactive Python debugger for use with --pdb.'
"For example: --pdbcls=IPython.terminal.debugger:TerminalPdb", 'For example: --pdbcls=IPython.terminal.debugger:TerminalPdb',
) )
group._addoption( group._addoption(
"--trace", '--trace',
dest="trace", dest='trace',
action="store_true", action='store_true',
help="Immediately break when running each test", help='Immediately break when running each test',
) )
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
import pdb import pdb
if config.getvalue("trace"): if config.getvalue('trace'):
config.pluginmanager.register(PdbTrace(), "pdbtrace") config.pluginmanager.register(PdbTrace(), 'pdbtrace')
if config.getvalue("usepdb"): if config.getvalue('usepdb'):
config.pluginmanager.register(PdbInvoke(), "pdbinvoke") config.pluginmanager.register(PdbInvoke(), 'pdbinvoke')
pytestPDB._saved.append( pytestPDB._saved.append(
(pdb.set_trace, pytestPDB._pluginmanager, pytestPDB._config) (pdb.set_trace, pytestPDB._pluginmanager, pytestPDB._config),
) )
pdb.set_trace = pytestPDB.set_trace pdb.set_trace = pytestPDB.set_trace
pytestPDB._pluginmanager = config.pluginmanager pytestPDB._pluginmanager = config.pluginmanager
@ -97,29 +99,29 @@ def pytest_configure(config: Config) -> None:
class pytestPDB: class pytestPDB:
"""Pseudo PDB that defers to the real pdb.""" """Pseudo PDB that defers to the real pdb."""
_pluginmanager: Optional[PytestPluginManager] = None _pluginmanager: PytestPluginManager | None = None
_config: Optional[Config] = None _config: Config | None = None
_saved: List[ _saved: list[
Tuple[Callable[..., None], Optional[PytestPluginManager], Optional[Config]] tuple[Callable[..., None], PytestPluginManager | None, Config | None]
] = [] ] = []
_recursive_debug = 0 _recursive_debug = 0
_wrapped_pdb_cls: Optional[Tuple[Type[Any], Type[Any]]] = None _wrapped_pdb_cls: tuple[type[Any], type[Any]] | None = None
@classmethod @classmethod
def _is_capturing(cls, capman: Optional["CaptureManager"]) -> Union[str, bool]: def _is_capturing(cls, capman: CaptureManager | None) -> str | bool:
if capman: if capman:
return capman.is_capturing() return capman.is_capturing()
return False return False
@classmethod @classmethod
def _import_pdb_cls(cls, capman: Optional["CaptureManager"]): def _import_pdb_cls(cls, capman: CaptureManager | None):
if not cls._config: if not cls._config:
import pdb import pdb
# Happens when using pytest.set_trace outside of a test. # Happens when using pytest.set_trace outside of a test.
return pdb.Pdb return pdb.Pdb
usepdb_cls = cls._config.getvalue("usepdb_cls") usepdb_cls = cls._config.getvalue('usepdb_cls')
if cls._wrapped_pdb_cls and cls._wrapped_pdb_cls[0] == usepdb_cls: if cls._wrapped_pdb_cls and cls._wrapped_pdb_cls[0] == usepdb_cls:
return cls._wrapped_pdb_cls[1] return cls._wrapped_pdb_cls[1]
@ -132,14 +134,14 @@ class pytestPDB:
mod = sys.modules[modname] mod = sys.modules[modname]
# Handle --pdbcls=pdb:pdb.Pdb (useful e.g. with pdbpp). # Handle --pdbcls=pdb:pdb.Pdb (useful e.g. with pdbpp).
parts = classname.split(".") parts = classname.split('.')
pdb_cls = getattr(mod, parts[0]) pdb_cls = getattr(mod, parts[0])
for part in parts[1:]: for part in parts[1:]:
pdb_cls = getattr(pdb_cls, part) pdb_cls = getattr(pdb_cls, part)
except Exception as exc: except Exception as exc:
value = ":".join((modname, classname)) value = ':'.join((modname, classname))
raise UsageError( raise UsageError(
f"--pdbcls: could not import {value!r}: {exc}" f'--pdbcls: could not import {value!r}: {exc}',
) from exc ) from exc
else: else:
import pdb import pdb
@ -151,7 +153,7 @@ class pytestPDB:
return wrapped_cls return wrapped_cls
@classmethod @classmethod
def _get_pdb_wrapper_class(cls, pdb_cls, capman: Optional["CaptureManager"]): def _get_pdb_wrapper_class(cls, pdb_cls, capman: CaptureManager | None):
import _pytest.config import _pytest.config
# Type ignored because mypy doesn't support "dynamic" # Type ignored because mypy doesn't support "dynamic"
@ -176,18 +178,18 @@ class pytestPDB:
capman = self._pytest_capman capman = self._pytest_capman
capturing = pytestPDB._is_capturing(capman) capturing = pytestPDB._is_capturing(capman)
if capturing: if capturing:
if capturing == "global": if capturing == 'global':
tw.sep(">", "PDB continue (IO-capturing resumed)") tw.sep('>', 'PDB continue (IO-capturing resumed)')
else: else:
tw.sep( tw.sep(
">", '>',
"PDB continue (IO-capturing resumed for %s)" 'PDB continue (IO-capturing resumed for %s)'
% capturing, % capturing,
) )
assert capman is not None assert capman is not None
capman.resume() capman.resume()
else: else:
tw.sep(">", "PDB continue") tw.sep('>', 'PDB continue')
assert cls._pluginmanager is not None assert cls._pluginmanager is not None
cls._pluginmanager.hook.pytest_leave_pdb(config=cls._config, pdb=self) cls._pluginmanager.hook.pytest_leave_pdb(config=cls._config, pdb=self)
self._continued = True self._continued = True
@ -205,7 +207,7 @@ class pytestPDB:
ret = super().do_quit(arg) ret = super().do_quit(arg)
if cls._recursive_debug == 0: if cls._recursive_debug == 0:
outcomes.exit("Quitting debugger") outcomes.exit('Quitting debugger')
return ret return ret
@ -231,7 +233,7 @@ class pytestPDB:
if f is None: if f is None:
# Find last non-hidden frame. # Find last non-hidden frame.
i = max(0, len(stack) - 1) i = max(0, len(stack) - 1)
while i and stack[i][0].f_locals.get("__tracebackhide__", False): while i and stack[i][0].f_locals.get('__tracebackhide__', False):
i -= 1 i -= 1
return stack, i return stack, i
@ -243,9 +245,9 @@ class pytestPDB:
import _pytest.config import _pytest.config
if cls._pluginmanager is None: if cls._pluginmanager is None:
capman: Optional[CaptureManager] = None capman: CaptureManager | None = None
else: else:
capman = cls._pluginmanager.getplugin("capturemanager") capman = cls._pluginmanager.getplugin('capturemanager')
if capman: if capman:
capman.suspend(in_=True) capman.suspend(in_=True)
@ -255,20 +257,20 @@ class pytestPDB:
if cls._recursive_debug == 0: if cls._recursive_debug == 0:
# Handle header similar to pdb.set_trace in py37+. # Handle header similar to pdb.set_trace in py37+.
header = kwargs.pop("header", None) header = kwargs.pop('header', None)
if header is not None: if header is not None:
tw.sep(">", header) tw.sep('>', header)
else: else:
capturing = cls._is_capturing(capman) capturing = cls._is_capturing(capman)
if capturing == "global": if capturing == 'global':
tw.sep(">", f"PDB {method} (IO-capturing turned off)") tw.sep('>', f'PDB {method} (IO-capturing turned off)')
elif capturing: elif capturing:
tw.sep( tw.sep(
">", '>',
f"PDB {method} (IO-capturing turned off for {capturing})", f'PDB {method} (IO-capturing turned off for {capturing})',
) )
else: else:
tw.sep(">", f"PDB {method}") tw.sep('>', f'PDB {method}')
_pdb = cls._import_pdb_cls(capman)(**kwargs) _pdb = cls._import_pdb_cls(capman)(**kwargs)
@ -280,15 +282,15 @@ class pytestPDB:
def set_trace(cls, *args, **kwargs) -> None: def set_trace(cls, *args, **kwargs) -> None:
"""Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing.""" """Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing."""
frame = sys._getframe().f_back frame = sys._getframe().f_back
_pdb = cls._init_pdb("set_trace", *args, **kwargs) _pdb = cls._init_pdb('set_trace', *args, **kwargs)
_pdb.set_trace(frame) _pdb.set_trace(frame)
class PdbInvoke: class PdbInvoke:
def pytest_exception_interact( def pytest_exception_interact(
self, node: Node, call: "CallInfo[Any]", report: BaseReport self, node: Node, call: CallInfo[Any], report: BaseReport,
) -> None: ) -> None:
capman = node.config.pluginmanager.getplugin("capturemanager") capman = node.config.pluginmanager.getplugin('capturemanager')
if capman: if capman:
capman.suspend_global_capture(in_=True) capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture() out, err = capman.read_global_capture()
@ -316,7 +318,7 @@ def wrap_pytest_function_for_tracing(pyfuncitem):
wrapper which actually enters pdb before calling the python function wrapper which actually enters pdb before calling the python function
itself, effectively leaving the user in the pdb prompt in the first itself, effectively leaving the user in the pdb prompt in the first
statement of the function.""" statement of the function."""
_pdb = pytestPDB._init_pdb("runcall") _pdb = pytestPDB._init_pdb('runcall')
testfunction = pyfuncitem.obj testfunction = pyfuncitem.obj
# we can't just return `partial(pdb.runcall, testfunction)` because (on # we can't just return `partial(pdb.runcall, testfunction)` because (on
@ -333,35 +335,35 @@ def wrap_pytest_function_for_tracing(pyfuncitem):
def maybe_wrap_pytest_function_for_tracing(pyfuncitem): def maybe_wrap_pytest_function_for_tracing(pyfuncitem):
"""Wrap the given pytestfunct item for tracing support if --trace was given in """Wrap the given pytestfunct item for tracing support if --trace was given in
the command line.""" the command line."""
if pyfuncitem.config.getvalue("trace"): if pyfuncitem.config.getvalue('trace'):
wrap_pytest_function_for_tracing(pyfuncitem) wrap_pytest_function_for_tracing(pyfuncitem)
def _enter_pdb( def _enter_pdb(
node: Node, excinfo: ExceptionInfo[BaseException], rep: BaseReport node: Node, excinfo: ExceptionInfo[BaseException], rep: BaseReport,
) -> BaseReport: ) -> BaseReport:
# XXX we re-use the TerminalReporter's terminalwriter # XXX we re-use the TerminalReporter's terminalwriter
# because this seems to avoid some encoding related troubles # because this seems to avoid some encoding related troubles
# for not completely clear reasons. # for not completely clear reasons.
tw = node.config.pluginmanager.getplugin("terminalreporter")._tw tw = node.config.pluginmanager.getplugin('terminalreporter')._tw
tw.line() tw.line()
showcapture = node.config.option.showcapture showcapture = node.config.option.showcapture
for sectionname, content in ( for sectionname, content in (
("stdout", rep.capstdout), ('stdout', rep.capstdout),
("stderr", rep.capstderr), ('stderr', rep.capstderr),
("log", rep.caplog), ('log', rep.caplog),
): ):
if showcapture in (sectionname, "all") and content: if showcapture in (sectionname, 'all') and content:
tw.sep(">", "captured " + sectionname) tw.sep('>', 'captured ' + sectionname)
if content[-1:] == "\n": if content[-1:] == '\n':
content = content[:-1] content = content[:-1]
tw.line(content) tw.line(content)
tw.sep(">", "traceback") tw.sep('>', 'traceback')
rep.toterminal(tw) rep.toterminal(tw)
tw.sep(">", "entering PDB") tw.sep('>', 'entering PDB')
tb = _postmortem_traceback(excinfo) tb = _postmortem_traceback(excinfo)
rep._pdbshown = True # type: ignore[attr-defined] rep._pdbshown = True # type: ignore[attr-defined]
post_mortem(tb) post_mortem(tb)
@ -386,8 +388,8 @@ def _postmortem_traceback(excinfo: ExceptionInfo[BaseException]) -> types.Traceb
def post_mortem(t: types.TracebackType) -> None: def post_mortem(t: types.TracebackType) -> None:
p = pytestPDB._init_pdb("post_mortem") p = pytestPDB._init_pdb('post_mortem')
p.reset() p.reset()
p.interaction(None, t) p.interaction(None, t)
if p.quitting: if p.quitting:
outcomes.exit("Quitting debugger") outcomes.exit('Quitting debugger')

View file

@ -8,6 +8,7 @@ All constants defined in this module should be either instances of
:class:`PytestWarning`, or :class:`UnformattedWarning` :class:`PytestWarning`, or :class:`UnformattedWarning`
in case of warnings which need to format their messages. in case of warnings which need to format their messages.
""" """
from __future__ import annotations
from warnings import warn from warnings import warn
@ -19,50 +20,50 @@ from _pytest.warning_types import UnformattedWarning
# set of plugins which have been integrated into the core; we use this list to ignore # set of plugins which have been integrated into the core; we use this list to ignore
# them during registration to avoid conflicts # them during registration to avoid conflicts
DEPRECATED_EXTERNAL_PLUGINS = { DEPRECATED_EXTERNAL_PLUGINS = {
"pytest_catchlog", 'pytest_catchlog',
"pytest_capturelog", 'pytest_capturelog',
"pytest_faulthandler", 'pytest_faulthandler',
} }
# This can be* removed pytest 8, but it's harmless and common, so no rush to remove. # This can be* removed pytest 8, but it's harmless and common, so no rush to remove.
# * If you're in the future: "could have been". # * If you're in the future: "could have been".
YIELD_FIXTURE = PytestDeprecationWarning( YIELD_FIXTURE = PytestDeprecationWarning(
"@pytest.yield_fixture is deprecated.\n" '@pytest.yield_fixture is deprecated.\n'
"Use @pytest.fixture instead; they are the same." 'Use @pytest.fixture instead; they are the same.',
) )
# This deprecation is never really meant to be removed. # This deprecation is never really meant to be removed.
PRIVATE = PytestDeprecationWarning("A private pytest class or function was used.") PRIVATE = PytestDeprecationWarning('A private pytest class or function was used.')
HOOK_LEGACY_PATH_ARG = UnformattedWarning( HOOK_LEGACY_PATH_ARG = UnformattedWarning(
PytestRemovedIn9Warning, PytestRemovedIn9Warning,
"The ({pylib_path_arg}: py.path.local) argument is deprecated, please use ({pathlib_path_arg}: pathlib.Path)\n" 'The ({pylib_path_arg}: py.path.local) argument is deprecated, please use ({pathlib_path_arg}: pathlib.Path)\n'
"see https://docs.pytest.org/en/latest/deprecations.html" 'see https://docs.pytest.org/en/latest/deprecations.html'
"#py-path-local-arguments-for-hooks-replaced-with-pathlib-path", '#py-path-local-arguments-for-hooks-replaced-with-pathlib-path',
) )
NODE_CTOR_FSPATH_ARG = UnformattedWarning( NODE_CTOR_FSPATH_ARG = UnformattedWarning(
PytestRemovedIn9Warning, PytestRemovedIn9Warning,
"The (fspath: py.path.local) argument to {node_type_name} is deprecated. " 'The (fspath: py.path.local) argument to {node_type_name} is deprecated. '
"Please use the (path: pathlib.Path) argument instead.\n" 'Please use the (path: pathlib.Path) argument instead.\n'
"See https://docs.pytest.org/en/latest/deprecations.html" 'See https://docs.pytest.org/en/latest/deprecations.html'
"#fspath-argument-for-node-constructors-replaced-with-pathlib-path", '#fspath-argument-for-node-constructors-replaced-with-pathlib-path',
) )
HOOK_LEGACY_MARKING = UnformattedWarning( HOOK_LEGACY_MARKING = UnformattedWarning(
PytestDeprecationWarning, PytestDeprecationWarning,
"The hook{type} {fullname} uses old-style configuration options (marks or attributes).\n" 'The hook{type} {fullname} uses old-style configuration options (marks or attributes).\n'
"Please use the pytest.hook{type}({hook_opts}) decorator instead\n" 'Please use the pytest.hook{type}({hook_opts}) decorator instead\n'
" to configure the hooks.\n" ' to configure the hooks.\n'
" See https://docs.pytest.org/en/latest/deprecations.html" ' See https://docs.pytest.org/en/latest/deprecations.html'
"#configuring-hook-specs-impls-using-markers", '#configuring-hook-specs-impls-using-markers',
) )
MARKED_FIXTURE = PytestRemovedIn9Warning( MARKED_FIXTURE = PytestRemovedIn9Warning(
"Marks applied to fixtures have no effect\n" 'Marks applied to fixtures have no effect\n'
"See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function" 'See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function',
) )
# You want to make some `__init__` or function "private". # You want to make some `__init__` or function "private".

View file

@ -1,15 +1,18 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Discover and run doctests in modules and test files.""" """Discover and run doctests in modules and test files."""
from __future__ import annotations
import bdb import bdb
from contextlib import contextmanager
import functools import functools
import inspect import inspect
import os import os
from pathlib import Path
import platform import platform
import sys import sys
import traceback import traceback
import types import types
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
@ -23,7 +26,6 @@ from typing import Tuple
from typing import Type from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
import warnings
from _pytest import outcomes from _pytest import outcomes
from _pytest._code.code import ExceptionInfo from _pytest._code.code import ExceptionInfo
@ -49,11 +51,11 @@ if TYPE_CHECKING:
import doctest import doctest
from typing import Self from typing import Self
DOCTEST_REPORT_CHOICE_NONE = "none" DOCTEST_REPORT_CHOICE_NONE = 'none'
DOCTEST_REPORT_CHOICE_CDIFF = "cdiff" DOCTEST_REPORT_CHOICE_CDIFF = 'cdiff'
DOCTEST_REPORT_CHOICE_NDIFF = "ndiff" DOCTEST_REPORT_CHOICE_NDIFF = 'ndiff'
DOCTEST_REPORT_CHOICE_UDIFF = "udiff" DOCTEST_REPORT_CHOICE_UDIFF = 'udiff'
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE = "only_first_failure" DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE = 'only_first_failure'
DOCTEST_REPORT_CHOICES = ( DOCTEST_REPORT_CHOICES = (
DOCTEST_REPORT_CHOICE_NONE, DOCTEST_REPORT_CHOICE_NONE,
@ -66,56 +68,56 @@ DOCTEST_REPORT_CHOICES = (
# Lazy definition of runner class # Lazy definition of runner class
RUNNER_CLASS = None RUNNER_CLASS = None
# Lazy definition of output checker class # Lazy definition of output checker class
CHECKER_CLASS: Optional[Type["doctest.OutputChecker"]] = None CHECKER_CLASS: type[doctest.OutputChecker] | None = None
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
parser.addini( parser.addini(
"doctest_optionflags", 'doctest_optionflags',
"Option flags for doctests", 'Option flags for doctests',
type="args", type='args',
default=["ELLIPSIS"], default=['ELLIPSIS'],
) )
parser.addini( parser.addini(
"doctest_encoding", "Encoding used for doctest files", default="utf-8" 'doctest_encoding', 'Encoding used for doctest files', default='utf-8',
) )
group = parser.getgroup("collect") group = parser.getgroup('collect')
group.addoption( group.addoption(
"--doctest-modules", '--doctest-modules',
action="store_true", action='store_true',
default=False, default=False,
help="Run doctests in all .py modules", help='Run doctests in all .py modules',
dest="doctestmodules", dest='doctestmodules',
) )
group.addoption( group.addoption(
"--doctest-report", '--doctest-report',
type=str.lower, type=str.lower,
default="udiff", default='udiff',
help="Choose another output format for diffs on doctest failure", help='Choose another output format for diffs on doctest failure',
choices=DOCTEST_REPORT_CHOICES, choices=DOCTEST_REPORT_CHOICES,
dest="doctestreport", dest='doctestreport',
) )
group.addoption( group.addoption(
"--doctest-glob", '--doctest-glob',
action="append", action='append',
default=[], default=[],
metavar="pat", metavar='pat',
help="Doctests file matching pattern, default: test*.txt", help='Doctests file matching pattern, default: test*.txt',
dest="doctestglob", dest='doctestglob',
) )
group.addoption( group.addoption(
"--doctest-ignore-import-errors", '--doctest-ignore-import-errors',
action="store_true", action='store_true',
default=False, default=False,
help="Ignore doctest collection errors", help='Ignore doctest collection errors',
dest="doctest_ignore_import_errors", dest='doctest_ignore_import_errors',
) )
group.addoption( group.addoption(
"--doctest-continue-on-failure", '--doctest-continue-on-failure',
action="store_true", action='store_true',
default=False, default=False,
help="For a given doctest, continue to run after the first failure", help='For a given doctest, continue to run after the first failure',
dest="doctest_continue_on_failure", dest='doctest_continue_on_failure',
) )
@ -128,11 +130,11 @@ def pytest_unconfigure() -> None:
def pytest_collect_file( def pytest_collect_file(
file_path: Path, file_path: Path,
parent: Collector, parent: Collector,
) -> Optional[Union["DoctestModule", "DoctestTextfile"]]: ) -> DoctestModule | DoctestTextfile | None:
config = parent.config config = parent.config
if file_path.suffix == ".py": if file_path.suffix == '.py':
if config.option.doctestmodules and not any( if config.option.doctestmodules and not any(
(_is_setup_py(file_path), _is_main_py(file_path)) (_is_setup_py(file_path), _is_main_py(file_path)),
): ):
return DoctestModule.from_parent(parent, path=file_path) return DoctestModule.from_parent(parent, path=file_path)
elif _is_doctest(config, file_path, parent): elif _is_doctest(config, file_path, parent):
@ -141,26 +143,26 @@ def pytest_collect_file(
def _is_setup_py(path: Path) -> bool: def _is_setup_py(path: Path) -> bool:
if path.name != "setup.py": if path.name != 'setup.py':
return False return False
contents = path.read_bytes() contents = path.read_bytes()
return b"setuptools" in contents or b"distutils" in contents return b'setuptools' in contents or b'distutils' in contents
def _is_doctest(config: Config, path: Path, parent: Collector) -> bool: def _is_doctest(config: Config, path: Path, parent: Collector) -> bool:
if path.suffix in (".txt", ".rst") and parent.session.isinitpath(path): if path.suffix in ('.txt', '.rst') and parent.session.isinitpath(path):
return True return True
globs = config.getoption("doctestglob") or ["test*.txt"] globs = config.getoption('doctestglob') or ['test*.txt']
return any(fnmatch_ex(glob, path) for glob in globs) return any(fnmatch_ex(glob, path) for glob in globs)
def _is_main_py(path: Path) -> bool: def _is_main_py(path: Path) -> bool:
return path.name == "__main__.py" return path.name == '__main__.py'
class ReprFailDoctest(TerminalRepr): class ReprFailDoctest(TerminalRepr):
def __init__( def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]] self, reprlocation_lines: Sequence[tuple[ReprFileLocation, Sequence[str]]],
) -> None: ) -> None:
self.reprlocation_lines = reprlocation_lines self.reprlocation_lines = reprlocation_lines
@ -172,12 +174,12 @@ class ReprFailDoctest(TerminalRepr):
class MultipleDoctestFailures(Exception): class MultipleDoctestFailures(Exception):
def __init__(self, failures: Sequence["doctest.DocTestFailure"]) -> None: def __init__(self, failures: Sequence[doctest.DocTestFailure]) -> None:
super().__init__() super().__init__()
self.failures = failures self.failures = failures
def _init_runner_class() -> Type["doctest.DocTestRunner"]: def _init_runner_class() -> type[doctest.DocTestRunner]:
import doctest import doctest
class PytestDoctestRunner(doctest.DebugRunner): class PytestDoctestRunner(doctest.DebugRunner):
@ -189,8 +191,8 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def __init__( def __init__(
self, self,
checker: Optional["doctest.OutputChecker"] = None, checker: doctest.OutputChecker | None = None,
verbose: Optional[bool] = None, verbose: bool | None = None,
optionflags: int = 0, optionflags: int = 0,
continue_on_failure: bool = True, continue_on_failure: bool = True,
) -> None: ) -> None:
@ -200,8 +202,8 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def report_failure( def report_failure(
self, self,
out, out,
test: "doctest.DocTest", test: doctest.DocTest,
example: "doctest.Example", example: doctest.Example,
got: str, got: str,
) -> None: ) -> None:
failure = doctest.DocTestFailure(test, example, got) failure = doctest.DocTestFailure(test, example, got)
@ -213,14 +215,14 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def report_unexpected_exception( def report_unexpected_exception(
self, self,
out, out,
test: "doctest.DocTest", test: doctest.DocTest,
example: "doctest.Example", example: doctest.Example,
exc_info: Tuple[Type[BaseException], BaseException, types.TracebackType], exc_info: tuple[type[BaseException], BaseException, types.TracebackType],
) -> None: ) -> None:
if isinstance(exc_info[1], OutcomeException): if isinstance(exc_info[1], OutcomeException):
raise exc_info[1] raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit): if isinstance(exc_info[1], bdb.BdbQuit):
outcomes.exit("Quitting debugger") outcomes.exit('Quitting debugger')
failure = doctest.UnexpectedException(test, example, exc_info) failure = doctest.UnexpectedException(test, example, exc_info)
if self.continue_on_failure: if self.continue_on_failure:
out.append(failure) out.append(failure)
@ -231,11 +233,11 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def _get_runner( def _get_runner(
checker: Optional["doctest.OutputChecker"] = None, checker: doctest.OutputChecker | None = None,
verbose: Optional[bool] = None, verbose: bool | None = None,
optionflags: int = 0, optionflags: int = 0,
continue_on_failure: bool = True, continue_on_failure: bool = True,
) -> "doctest.DocTestRunner": ) -> doctest.DocTestRunner:
# We need this in order to do a lazy import on doctest # We need this in order to do a lazy import on doctest
global RUNNER_CLASS global RUNNER_CLASS
if RUNNER_CLASS is None: if RUNNER_CLASS is None:
@ -254,9 +256,9 @@ class DoctestItem(Item):
def __init__( def __init__(
self, self,
name: str, name: str,
parent: "Union[DoctestTextfile, DoctestModule]", parent: Union[DoctestTextfile, DoctestModule],
runner: "doctest.DocTestRunner", runner: doctest.DocTestRunner,
dtest: "doctest.DocTest", dtest: doctest.DocTest,
) -> None: ) -> None:
super().__init__(name, parent) super().__init__(name, parent)
self.runner = runner self.runner = runner
@ -273,31 +275,31 @@ class DoctestItem(Item):
@classmethod @classmethod
def from_parent( # type: ignore[override] def from_parent( # type: ignore[override]
cls, cls,
parent: "Union[DoctestTextfile, DoctestModule]", parent: Union[DoctestTextfile, DoctestModule],
*, *,
name: str, name: str,
runner: "doctest.DocTestRunner", runner: doctest.DocTestRunner,
dtest: "doctest.DocTest", dtest: doctest.DocTest,
) -> "Self": ) -> Self:
# incompatible signature due to imposed limits on subclass # incompatible signature due to imposed limits on subclass
"""The public named constructor.""" """The public named constructor."""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest) return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def _initrequest(self) -> None: def _initrequest(self) -> None:
self.funcargs: Dict[str, object] = {} self.funcargs: dict[str, object] = {}
self._request = TopRequest(self, _ispytest=True) # type: ignore[arg-type] self._request = TopRequest(self, _ispytest=True) # type: ignore[arg-type]
def setup(self) -> None: def setup(self) -> None:
self._request._fillfixtures() self._request._fillfixtures()
globs = dict(getfixture=self._request.getfixturevalue) globs = dict(getfixture=self._request.getfixturevalue)
for name, value in self._request.getfixturevalue("doctest_namespace").items(): for name, value in self._request.getfixturevalue('doctest_namespace').items():
globs[name] = value globs[name] = value
self.dtest.globs.update(globs) self.dtest.globs.update(globs)
def runtest(self) -> None: def runtest(self) -> None:
_check_all_skipped(self.dtest) _check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin() self._disable_output_capturing_for_darwin()
failures: List["doctest.DocTestFailure"] = [] failures: list[doctest.DocTestFailure] = []
# Type ignored because we change the type of `out` from what # Type ignored because we change the type of `out` from what
# doctest expects. # doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type] self.runner.run(self.dtest, out=failures) # type: ignore[arg-type]
@ -306,9 +308,9 @@ class DoctestItem(Item):
def _disable_output_capturing_for_darwin(self) -> None: def _disable_output_capturing_for_darwin(self) -> None:
"""Disable output capturing. Otherwise, stdout is lost to doctest (#985).""" """Disable output capturing. Otherwise, stdout is lost to doctest (#985)."""
if platform.system() != "Darwin": if platform.system() != 'Darwin':
return return
capman = self.config.pluginmanager.getplugin("capturemanager") capman = self.config.pluginmanager.getplugin('capturemanager')
if capman: if capman:
capman.suspend_global_capture(in_=True) capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture() out, err = capman.read_global_capture()
@ -319,14 +321,14 @@ class DoctestItem(Item):
def repr_failure( # type: ignore[override] def repr_failure( # type: ignore[override]
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
) -> Union[str, TerminalRepr]: ) -> str | TerminalRepr:
import doctest import doctest
failures: Optional[ failures: None | (
Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]] Sequence[doctest.DocTestFailure | doctest.UnexpectedException]
] = None ) = None
if isinstance( if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException) excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException),
): ):
failures = [excinfo.value] failures = [excinfo.value]
elif isinstance(excinfo.value, MultipleDoctestFailures): elif isinstance(excinfo.value, MultipleDoctestFailures):
@ -348,43 +350,43 @@ class DoctestItem(Item):
# TODO: ReprFileLocation doesn't expect a None lineno. # TODO: ReprFileLocation doesn't expect a None lineno.
reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type] reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type]
checker = _get_checker() checker = _get_checker()
report_choice = _get_report_choice(self.config.getoption("doctestreport")) report_choice = _get_report_choice(self.config.getoption('doctestreport'))
if lineno is not None: if lineno is not None:
assert failure.test.docstring is not None assert failure.test.docstring is not None
lines = failure.test.docstring.splitlines(False) lines = failure.test.docstring.splitlines(False)
# add line numbers to the left of the error message # add line numbers to the left of the error message
assert test.lineno is not None assert test.lineno is not None
lines = [ lines = [
"%03d %s" % (i + test.lineno + 1, x) for (i, x) in enumerate(lines) '%03d %s' % (i + test.lineno + 1, x) for (i, x) in enumerate(lines)
] ]
# trim docstring error lines to 10 # trim docstring error lines to 10
lines = lines[max(example.lineno - 9, 0) : example.lineno + 1] lines = lines[max(example.lineno - 9, 0): example.lineno + 1]
else: else:
lines = [ lines = [
"EXAMPLE LOCATION UNKNOWN, not showing all tests of that example" 'EXAMPLE LOCATION UNKNOWN, not showing all tests of that example',
] ]
indent = ">>>" indent = '>>>'
for line in example.source.splitlines(): for line in example.source.splitlines():
lines.append(f"??? {indent} {line}") lines.append(f'??? {indent} {line}')
indent = "..." indent = '...'
if isinstance(failure, doctest.DocTestFailure): if isinstance(failure, doctest.DocTestFailure):
lines += checker.output_difference( lines += checker.output_difference(
example, failure.got, report_choice example, failure.got, report_choice,
).split("\n") ).split('\n')
else: else:
inner_excinfo = ExceptionInfo.from_exc_info(failure.exc_info) inner_excinfo = ExceptionInfo.from_exc_info(failure.exc_info)
lines += ["UNEXPECTED EXCEPTION: %s" % repr(inner_excinfo.value)] lines += ['UNEXPECTED EXCEPTION: %s' % repr(inner_excinfo.value)]
lines += [ lines += [
x.strip("\n") for x in traceback.format_exception(*failure.exc_info) x.strip('\n') for x in traceback.format_exception(*failure.exc_info)
] ]
reprlocation_lines.append((reprlocation, lines)) reprlocation_lines.append((reprlocation, lines))
return ReprFailDoctest(reprlocation_lines) return ReprFailDoctest(reprlocation_lines)
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]: def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
return self.path, self.dtest.lineno, "[doctest] %s" % self.name return self.path, self.dtest.lineno, '[doctest] %s' % self.name
def _get_flag_lookup() -> Dict[str, int]: def _get_flag_lookup() -> dict[str, int]:
import doctest import doctest
return dict( return dict(
@ -401,7 +403,7 @@ def _get_flag_lookup() -> Dict[str, int]:
def get_optionflags(config: Config) -> int: def get_optionflags(config: Config) -> int:
optionflags_str = config.getini("doctest_optionflags") optionflags_str = config.getini('doctest_optionflags')
flag_lookup_table = _get_flag_lookup() flag_lookup_table = _get_flag_lookup()
flag_acc = 0 flag_acc = 0
for flag in optionflags_str: for flag in optionflags_str:
@ -410,11 +412,11 @@ def get_optionflags(config: Config) -> int:
def _get_continue_on_failure(config: Config) -> bool: def _get_continue_on_failure(config: Config) -> bool:
continue_on_failure: bool = config.getvalue("doctest_continue_on_failure") continue_on_failure: bool = config.getvalue('doctest_continue_on_failure')
if continue_on_failure: if continue_on_failure:
# We need to turn off this if we use pdb since we should stop at # We need to turn off this if we use pdb since we should stop at
# the first failure. # the first failure.
if config.getvalue("usepdb"): if config.getvalue('usepdb'):
continue_on_failure = False continue_on_failure = False
return continue_on_failure return continue_on_failure
@ -427,11 +429,11 @@ class DoctestTextfile(Module):
# Inspired by doctest.testfile; ideally we would use it directly, # Inspired by doctest.testfile; ideally we would use it directly,
# but it doesn't support passing a custom checker. # but it doesn't support passing a custom checker.
encoding = self.config.getini("doctest_encoding") encoding = self.config.getini('doctest_encoding')
text = self.path.read_text(encoding) text = self.path.read_text(encoding)
filename = str(self.path) filename = str(self.path)
name = self.path.name name = self.path.name
globs = {"__name__": "__main__"} globs = {'__name__': '__main__'}
optionflags = get_optionflags(self.config) optionflags = get_optionflags(self.config)
@ -446,25 +448,25 @@ class DoctestTextfile(Module):
test = parser.get_doctest(text, globs, name, filename, 0) test = parser.get_doctest(text, globs, name, filename, 0)
if test.examples: if test.examples:
yield DoctestItem.from_parent( yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test self, name=test.name, runner=runner, dtest=test,
) )
def _check_all_skipped(test: "doctest.DocTest") -> None: def _check_all_skipped(test: doctest.DocTest) -> None:
"""Raise pytest.skip() if all examples in the given DocTest have the SKIP """Raise pytest.skip() if all examples in the given DocTest have the SKIP
option set.""" option set."""
import doctest import doctest
all_skipped = all(x.options.get(doctest.SKIP, False) for x in test.examples) all_skipped = all(x.options.get(doctest.SKIP, False) for x in test.examples)
if all_skipped: if all_skipped:
skip("all tests skipped by +SKIP option") skip('all tests skipped by +SKIP option')
def _is_mocked(obj: object) -> bool: def _is_mocked(obj: object) -> bool:
"""Return if an object is possibly a mock object by checking the """Return if an object is possibly a mock object by checking the
existence of a highly improbable attribute.""" existence of a highly improbable attribute."""
return ( return (
safe_getattr(obj, "pytest_mock_example_attribute_that_shouldnt_exist", None) safe_getattr(obj, 'pytest_mock_example_attribute_that_shouldnt_exist', None)
is not None is not None
) )
@ -476,7 +478,7 @@ def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
real_unwrap = inspect.unwrap real_unwrap = inspect.unwrap
def _mock_aware_unwrap( def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None func: Callable[..., Any], *, stop: Callable[[Any], Any] | None = None,
) -> Any: ) -> Any:
try: try:
if stop is None or stop is _is_mocked: if stop is None or stop is _is_mocked:
@ -485,9 +487,9 @@ def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func)) return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e: except Exception as e:
warnings.warn( warnings.warn(
f"Got {e!r} when unwrapping {func!r}. This is usually caused " f'Got {e!r} when unwrapping {func!r}. This is usually caused '
"by a violation of Python's object protocol; see e.g. " "by a violation of Python's object protocol; see e.g. "
"https://github.com/pytest-dev/pytest/issues/5080", 'https://github.com/pytest-dev/pytest/issues/5080',
PytestWarning, PytestWarning,
) )
raise raise
@ -518,9 +520,9 @@ class DoctestModule(Module):
line number is returned. This will be reported upstream. #8796 line number is returned. This will be reported upstream. #8796
""" """
if isinstance(obj, property): if isinstance(obj, property):
obj = getattr(obj, "fget", obj) obj = getattr(obj, 'fget', obj)
if hasattr(obj, "__wrapped__"): if hasattr(obj, '__wrapped__'):
# Get the main obj in case of it being wrapped # Get the main obj in case of it being wrapped
obj = inspect.unwrap(obj) obj = inspect.unwrap(obj)
@ -531,14 +533,14 @@ class DoctestModule(Module):
) )
def _find( def _find(
self, tests, obj, name, module, source_lines, globs, seen self, tests, obj, name, module, source_lines, globs, seen,
) -> None: ) -> None:
if _is_mocked(obj): if _is_mocked(obj):
return return
with _patch_unwrap_mock_aware(): with _patch_unwrap_mock_aware():
# Type ignored because this is a private function. # Type ignored because this is a private function.
super()._find( # type:ignore[misc] super()._find( # type:ignore[misc]
tests, obj, name, module, source_lines, globs, seen tests, obj, name, module, source_lines, globs, seen,
) )
if sys.version_info < (3, 13): if sys.version_info < (3, 13):
@ -561,8 +563,8 @@ class DoctestModule(Module):
try: try:
module = self.obj module = self.obj
except Collector.CollectError: except Collector.CollectError:
if self.config.getvalue("doctest_ignore_import_errors"): if self.config.getvalue('doctest_ignore_import_errors'):
skip("unable to import module %r" % self.path) skip('unable to import module %r' % self.path)
else: else:
raise raise
@ -583,11 +585,11 @@ class DoctestModule(Module):
for test in finder.find(module, module.__name__): for test in finder.find(module, module.__name__):
if test.examples: # skip empty doctests if test.examples: # skip empty doctests
yield DoctestItem.from_parent( yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test self, name=test.name, runner=runner, dtest=test,
) )
def _init_checker_class() -> Type["doctest.OutputChecker"]: def _init_checker_class() -> type[doctest.OutputChecker]:
import doctest import doctest
import re import re
@ -633,7 +635,7 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return False return False
def remove_prefixes(regex: Pattern[str], txt: str) -> str: def remove_prefixes(regex: Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt) return re.sub(regex, r'\1\2', txt)
if allow_unicode: if allow_unicode:
want = remove_prefixes(self._unicode_literal_re, want) want = remove_prefixes(self._unicode_literal_re, want)
@ -655,10 +657,10 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return got return got
offset = 0 offset = 0
for w, g in zip(wants, gots): for w, g in zip(wants, gots):
fraction: Optional[str] = w.group("fraction") fraction: str | None = w.group('fraction')
exponent: Optional[str] = w.group("exponent1") exponent: str | None = w.group('exponent1')
if exponent is None: if exponent is None:
exponent = w.group("exponent2") exponent = w.group('exponent2')
precision = 0 if fraction is None else len(fraction) precision = 0 if fraction is None else len(fraction)
if exponent is not None: if exponent is not None:
precision -= int(exponent) precision -= int(exponent)
@ -667,7 +669,7 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
# got with the text we want, so that it will match when we # got with the text we want, so that it will match when we
# check the string literally. # check the string literally.
got = ( got = (
got[: g.start() + offset] + w.group() + got[g.end() + offset :] got[: g.start() + offset] + w.group() + got[g.end() + offset:]
) )
offset += w.end() - w.start() - (g.end() - g.start()) offset += w.end() - w.start() - (g.end() - g.start())
return got return got
@ -675,7 +677,7 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return LiteralsOutputChecker return LiteralsOutputChecker
def _get_checker() -> "doctest.OutputChecker": def _get_checker() -> doctest.OutputChecker:
"""Return a doctest.OutputChecker subclass that supports some """Return a doctest.OutputChecker subclass that supports some
additional options: additional options:
@ -699,21 +701,21 @@ def _get_allow_unicode_flag() -> int:
"""Register and return the ALLOW_UNICODE flag.""" """Register and return the ALLOW_UNICODE flag."""
import doctest import doctest
return doctest.register_optionflag("ALLOW_UNICODE") return doctest.register_optionflag('ALLOW_UNICODE')
def _get_allow_bytes_flag() -> int: def _get_allow_bytes_flag() -> int:
"""Register and return the ALLOW_BYTES flag.""" """Register and return the ALLOW_BYTES flag."""
import doctest import doctest
return doctest.register_optionflag("ALLOW_BYTES") return doctest.register_optionflag('ALLOW_BYTES')
def _get_number_flag() -> int: def _get_number_flag() -> int:
"""Register and return the NUMBER flag.""" """Register and return the NUMBER flag."""
import doctest import doctest
return doctest.register_optionflag("NUMBER") return doctest.register_optionflag('NUMBER')
def _get_report_choice(key: str) -> int: def _get_report_choice(key: str) -> int:
@ -733,8 +735,8 @@ def _get_report_choice(key: str) -> int:
}[key] }[key]
@fixture(scope="session") @fixture(scope='session')
def doctest_namespace() -> Dict[str, Any]: def doctest_namespace() -> dict[str, Any]:
"""Fixture that returns a :py:class:`dict` that will be injected into the """Fixture that returns a :py:class:`dict` that will be injected into the
namespace of doctests. namespace of doctests.

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import os import os
import sys import sys
from typing import Generator from typing import Generator
import pytest
from _pytest.config import Config from _pytest.config import Config
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.nodes import Item from _pytest.nodes import Item
from _pytest.stash import StashKey from _pytest.stash import StashKey
import pytest
fault_handler_original_stderr_fd_key = StashKey[int]() fault_handler_original_stderr_fd_key = StashKey[int]()
@ -15,10 +17,10 @@ fault_handler_stderr_fd_key = StashKey[int]()
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
help = ( help = (
"Dump the traceback of all threads if a test takes " 'Dump the traceback of all threads if a test takes '
"more than TIMEOUT seconds to finish" 'more than TIMEOUT seconds to finish'
) )
parser.addini("faulthandler_timeout", help, default=0.0) parser.addini('faulthandler_timeout', help, default=0.0)
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
@ -66,7 +68,7 @@ def get_stderr_fileno() -> int:
def get_timeout_config_value(config: Config) -> float: def get_timeout_config_value(config: Config) -> float:
return float(config.getini("faulthandler_timeout") or 0.0) return float(config.getini('faulthandler_timeout') or 0.0)
@pytest.hookimpl(wrapper=True, trylast=True) @pytest.hookimpl(wrapper=True, trylast=True)

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,6 @@
"""Provides a function to report all internal modules for using freezing """Provides a function to report all internal modules for using freezing
tools.""" tools."""
from __future__ import annotations
import types import types
from typing import Iterator from typing import Iterator
@ -7,7 +8,7 @@ from typing import List
from typing import Union from typing import Union
def freeze_includes() -> List[str]: def freeze_includes() -> list[str]:
"""Return a list of module names used by pytest that should be """Return a list of module names used by pytest that should be
included by cx_freeze.""" included by cx_freeze."""
import _pytest import _pytest
@ -17,8 +18,8 @@ def freeze_includes() -> List[str]:
def _iter_all_modules( def _iter_all_modules(
package: Union[str, types.ModuleType], package: str | types.ModuleType,
prefix: str = "", prefix: str = '',
) -> Iterator[str]: ) -> Iterator[str]:
"""Iterate over the names of all modules that can be found in the given """Iterate over the names of all modules that can be found in the given
package, recursively. package, recursively.
@ -36,10 +37,10 @@ def _iter_all_modules(
# Type ignored because typeshed doesn't define ModuleType.__path__ # Type ignored because typeshed doesn't define ModuleType.__path__
# (only defined on packages). # (only defined on packages).
package_path = package.__path__ # type: ignore[attr-defined] package_path = package.__path__ # type: ignore[attr-defined]
path, prefix = package_path[0], package.__name__ + "." path, prefix = package_path[0], package.__name__ + '.'
for _, name, is_package in pkgutil.iter_modules([path]): for _, name, is_package in pkgutil.iter_modules([path]):
if is_package: if is_package:
for m in _iter_all_modules(os.path.join(path, name), prefix=name + "."): for m in _iter_all_modules(os.path.join(path, name), prefix=name + '.'):
yield prefix + m yield prefix + m
else: else:
yield prefix + name yield prefix + name

View file

@ -1,19 +1,21 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Version info, help messages, tracing configuration.""" """Version info, help messages, tracing configuration."""
from argparse import Action from __future__ import annotations
import os import os
import sys import sys
from argparse import Action
from typing import Generator from typing import Generator
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
import pytest
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
from _pytest.config import PrintHelp from _pytest.config import PrintHelp
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
import pytest
class HelpAction(Action): class HelpAction(Action):
@ -40,63 +42,63 @@ class HelpAction(Action):
setattr(namespace, self.dest, self.const) setattr(namespace, self.dest, self.const)
# We should only skip the rest of the parsing after preparse is done. # We should only skip the rest of the parsing after preparse is done.
if getattr(parser._parser, "after_preparse", False): if getattr(parser._parser, 'after_preparse', False):
raise PrintHelp raise PrintHelp
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig") group = parser.getgroup('debugconfig')
group.addoption( group.addoption(
"--version", '--version',
"-V", '-V',
action="count", action='count',
default=0, default=0,
dest="version", dest='version',
help="Display pytest version and information about plugins. " help='Display pytest version and information about plugins. '
"When given twice, also display information about plugins.", 'When given twice, also display information about plugins.',
) )
group._addoption( group._addoption(
"-h", '-h',
"--help", '--help',
action=HelpAction, action=HelpAction,
dest="help", dest='help',
help="Show help message and configuration info", help='Show help message and configuration info',
) )
group._addoption( group._addoption(
"-p", '-p',
action="append", action='append',
dest="plugins", dest='plugins',
default=[], default=[],
metavar="name", metavar='name',
help="Early-load given plugin module name or entry point (multi-allowed). " help='Early-load given plugin module name or entry point (multi-allowed). '
"To avoid loading of plugins, use the `no:` prefix, e.g. " 'To avoid loading of plugins, use the `no:` prefix, e.g. '
"`no:doctest`.", '`no:doctest`.',
) )
group.addoption( group.addoption(
"--traceconfig", '--traceconfig',
"--trace-config", '--trace-config',
action="store_true", action='store_true',
default=False, default=False,
help="Trace considerations of conftest.py files", help='Trace considerations of conftest.py files',
) )
group.addoption( group.addoption(
"--debug", '--debug',
action="store", action='store',
nargs="?", nargs='?',
const="pytestdebug.log", const='pytestdebug.log',
dest="debug", dest='debug',
metavar="DEBUG_FILE_NAME", metavar='DEBUG_FILE_NAME',
help="Store internal tracing debug information in this log file. " help='Store internal tracing debug information in this log file. '
"This file is opened with 'w' and truncated as a result, care advised. " "This file is opened with 'w' and truncated as a result, care advised. "
"Default: pytestdebug.log.", 'Default: pytestdebug.log.',
) )
group._addoption( group._addoption(
"-o", '-o',
"--override-ini", '--override-ini',
dest="override_ini", dest='override_ini',
action="append", action='append',
help='Override ini option with "option=value" style, ' help='Override ini option with "option=value" style, '
"e.g. `-o xfail_strict=True -o cache_dir=cache`.", 'e.g. `-o xfail_strict=True -o cache_dir=cache`.',
) )
@ -107,24 +109,24 @@ def pytest_cmdline_parse() -> Generator[None, Config, Config]:
if config.option.debug: if config.option.debug:
# --debug | --debug <file.log> was provided. # --debug | --debug <file.log> was provided.
path = config.option.debug path = config.option.debug
debugfile = open(path, "w", encoding="utf-8") debugfile = open(path, 'w', encoding='utf-8')
debugfile.write( debugfile.write(
"versions pytest-{}, " 'versions pytest-{}, '
"python-{}\ninvocation_dir={}\ncwd={}\nargs={}\n\n".format( 'python-{}\ninvocation_dir={}\ncwd={}\nargs={}\n\n'.format(
pytest.__version__, pytest.__version__,
".".join(map(str, sys.version_info)), '.'.join(map(str, sys.version_info)),
config.invocation_params.dir, config.invocation_params.dir,
os.getcwd(), os.getcwd(),
config.invocation_params.args, config.invocation_params.args,
) ),
) )
config.trace.root.setwriter(debugfile.write) config.trace.root.setwriter(debugfile.write)
undo_tracing = config.pluginmanager.enable_tracing() undo_tracing = config.pluginmanager.enable_tracing()
sys.stderr.write("writing pytest debug information to %s\n" % path) sys.stderr.write('writing pytest debug information to %s\n' % path)
def unset_tracing() -> None: def unset_tracing() -> None:
debugfile.close() debugfile.close()
sys.stderr.write("wrote pytest debug information to %s\n" % debugfile.name) sys.stderr.write('wrote pytest debug information to %s\n' % debugfile.name)
config.trace.root.setwriter(None) config.trace.root.setwriter(None)
undo_tracing() undo_tracing()
@ -136,17 +138,17 @@ def pytest_cmdline_parse() -> Generator[None, Config, Config]:
def showversion(config: Config) -> None: def showversion(config: Config) -> None:
if config.option.version > 1: if config.option.version > 1:
sys.stdout.write( sys.stdout.write(
f"This is pytest version {pytest.__version__}, imported from {pytest.__file__}\n" f'This is pytest version {pytest.__version__}, imported from {pytest.__file__}\n',
) )
plugininfo = getpluginversioninfo(config) plugininfo = getpluginversioninfo(config)
if plugininfo: if plugininfo:
for line in plugininfo: for line in plugininfo:
sys.stdout.write(line + "\n") sys.stdout.write(line + '\n')
else: else:
sys.stdout.write(f"pytest {pytest.__version__}\n") sys.stdout.write(f'pytest {pytest.__version__}\n')
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.version > 0: if config.option.version > 0:
showversion(config) showversion(config)
return 0 return 0
@ -161,30 +163,30 @@ def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def showhelp(config: Config) -> None: def showhelp(config: Config) -> None:
import textwrap import textwrap
reporter: Optional[TerminalReporter] = config.pluginmanager.get_plugin( reporter: TerminalReporter | None = config.pluginmanager.get_plugin(
"terminalreporter" 'terminalreporter',
) )
assert reporter is not None assert reporter is not None
tw = reporter._tw tw = reporter._tw
tw.write(config._parser.optparser.format_help()) tw.write(config._parser.optparser.format_help())
tw.line() tw.line()
tw.line( tw.line(
"[pytest] ini-options in the first " '[pytest] ini-options in the first '
"pytest.ini|tox.ini|setup.cfg|pyproject.toml file found:" 'pytest.ini|tox.ini|setup.cfg|pyproject.toml file found:',
) )
tw.line() tw.line()
columns = tw.fullwidth # costly call columns = tw.fullwidth # costly call
indent_len = 24 # based on argparse's max_help_position=24 indent_len = 24 # based on argparse's max_help_position=24
indent = " " * indent_len indent = ' ' * indent_len
for name in config._parser._ininames: for name in config._parser._ininames:
help, type, default = config._parser._inidict[name] help, type, default = config._parser._inidict[name]
if type is None: if type is None:
type = "string" type = 'string'
if help is None: if help is None:
raise TypeError(f"help argument cannot be None for {name}") raise TypeError(f'help argument cannot be None for {name}')
spec = f"{name} ({type}):" spec = f'{name} ({type}):'
tw.write(" %s" % spec) tw.write(' %s' % spec)
spec_len = len(spec) spec_len = len(spec)
if spec_len > (indent_len - 3): if spec_len > (indent_len - 3):
# Display help starting at a new line. # Display help starting at a new line.
@ -201,7 +203,7 @@ def showhelp(config: Config) -> None:
tw.line(line) tw.line(line)
else: else:
# Display help starting after the spec, following lines indented. # Display help starting after the spec, following lines indented.
tw.write(" " * (indent_len - spec_len - 2)) tw.write(' ' * (indent_len - spec_len - 2))
wrapped = textwrap.wrap(help, columns - indent_len, break_on_hyphens=False) wrapped = textwrap.wrap(help, columns - indent_len, break_on_hyphens=False)
if wrapped: if wrapped:
@ -210,62 +212,62 @@ def showhelp(config: Config) -> None:
tw.line(indent + line) tw.line(indent + line)
tw.line() tw.line()
tw.line("Environment variables:") tw.line('Environment variables:')
vars = [ vars = [
("PYTEST_ADDOPTS", "Extra command line options"), ('PYTEST_ADDOPTS', 'Extra command line options'),
("PYTEST_PLUGINS", "Comma-separated plugins to load during startup"), ('PYTEST_PLUGINS', 'Comma-separated plugins to load during startup'),
("PYTEST_DISABLE_PLUGIN_AUTOLOAD", "Set to disable plugin auto-loading"), ('PYTEST_DISABLE_PLUGIN_AUTOLOAD', 'Set to disable plugin auto-loading'),
("PYTEST_DEBUG", "Set to enable debug tracing of pytest's internals"), ('PYTEST_DEBUG', "Set to enable debug tracing of pytest's internals"),
] ]
for name, help in vars: for name, help in vars:
tw.line(f" {name:<24} {help}") tw.line(f' {name:<24} {help}')
tw.line() tw.line()
tw.line() tw.line()
tw.line("to see available markers type: pytest --markers") tw.line('to see available markers type: pytest --markers')
tw.line("to see available fixtures type: pytest --fixtures") tw.line('to see available fixtures type: pytest --fixtures')
tw.line( tw.line(
"(shown according to specified file_or_dir or current dir " '(shown according to specified file_or_dir or current dir '
"if not specified; fixtures with leading '_' are only shown " "if not specified; fixtures with leading '_' are only shown "
"with the '-v' option" "with the '-v' option",
) )
for warningreport in reporter.stats.get("warnings", []): for warningreport in reporter.stats.get('warnings', []):
tw.line("warning : " + warningreport.message, red=True) tw.line('warning : ' + warningreport.message, red=True)
return return
conftest_options = [("pytest_plugins", "list of plugin names to load")] conftest_options = [('pytest_plugins', 'list of plugin names to load')]
def getpluginversioninfo(config: Config) -> List[str]: def getpluginversioninfo(config: Config) -> list[str]:
lines = [] lines = []
plugininfo = config.pluginmanager.list_plugin_distinfo() plugininfo = config.pluginmanager.list_plugin_distinfo()
if plugininfo: if plugininfo:
lines.append("setuptools registered plugins:") lines.append('setuptools registered plugins:')
for plugin, dist in plugininfo: for plugin, dist in plugininfo:
loc = getattr(plugin, "__file__", repr(plugin)) loc = getattr(plugin, '__file__', repr(plugin))
content = f"{dist.project_name}-{dist.version} at {loc}" content = f'{dist.project_name}-{dist.version} at {loc}'
lines.append(" " + content) lines.append(' ' + content)
return lines return lines
def pytest_report_header(config: Config) -> List[str]: def pytest_report_header(config: Config) -> list[str]:
lines = [] lines = []
if config.option.debug or config.option.traceconfig: if config.option.debug or config.option.traceconfig:
lines.append(f"using: pytest-{pytest.__version__}") lines.append(f'using: pytest-{pytest.__version__}')
verinfo = getpluginversioninfo(config) verinfo = getpluginversioninfo(config)
if verinfo: if verinfo:
lines.extend(verinfo) lines.extend(verinfo)
if config.option.traceconfig: if config.option.traceconfig:
lines.append("active plugins:") lines.append('active plugins:')
items = config.pluginmanager.list_name_plugin() items = config.pluginmanager.list_name_plugin()
for name, plugin in items: for name, plugin in items:
if hasattr(plugin, "__file__"): if hasattr(plugin, '__file__'):
r = plugin.__file__ r = plugin.__file__
else: else:
r = repr(plugin) r = repr(plugin)
lines.append(f" {name:<20}: {r}") lines.append(f' {name:<20}: {r}')
return lines return lines

View file

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Hook specifications for pytest plugins which are invoked by pytest itself """Hook specifications for pytest plugins which are invoked by pytest itself
and by builtin plugins.""" and by builtin plugins."""
from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Dict from typing import Dict
@ -45,7 +47,7 @@ if TYPE_CHECKING:
from _pytest.terminal import TestShortLogReport from _pytest.terminal import TestShortLogReport
hookspec = HookspecMarker("pytest") hookspec = HookspecMarker('pytest')
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# Initialization hooks called for every plugin # Initialization hooks called for every plugin
@ -53,7 +55,7 @@ hookspec = HookspecMarker("pytest")
@hookspec(historic=True) @hookspec(historic=True)
def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None: def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
"""Called at plugin registration time to allow adding new hooks via a call to """Called at plugin registration time to allow adding new hooks via a call to
:func:`pluginmanager.add_hookspecs(module_or_class, prefix) <pytest.PytestPluginManager.add_hookspecs>`. :func:`pluginmanager.add_hookspecs(module_or_class, prefix) <pytest.PytestPluginManager.add_hookspecs>`.
@ -72,9 +74,9 @@ def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None:
@hookspec(historic=True) @hookspec(historic=True)
def pytest_plugin_registered( def pytest_plugin_registered(
plugin: "_PluggyPlugin", plugin: _PluggyPlugin,
plugin_name: str, plugin_name: str,
manager: "PytestPluginManager", manager: PytestPluginManager,
) -> None: ) -> None:
"""A new pytest plugin got registered. """A new pytest plugin got registered.
@ -96,7 +98,7 @@ def pytest_plugin_registered(
@hookspec(historic=True) @hookspec(historic=True)
def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") -> None: def pytest_addoption(parser: Parser, pluginmanager: PytestPluginManager) -> None:
"""Register argparse-style options and ini-style config values, """Register argparse-style options and ini-style config values,
called once at the beginning of a test run. called once at the beginning of a test run.
@ -137,7 +139,7 @@ def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") ->
@hookspec(historic=True) @hookspec(historic=True)
def pytest_configure(config: "Config") -> None: def pytest_configure(config: Config) -> None:
"""Allow plugins and conftest files to perform initial configuration. """Allow plugins and conftest files to perform initial configuration.
.. note:: .. note::
@ -162,8 +164,8 @@ def pytest_configure(config: "Config") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_cmdline_parse( def pytest_cmdline_parse(
pluginmanager: "PytestPluginManager", args: List[str] pluginmanager: PytestPluginManager, args: list[str],
) -> Optional["Config"]: ) -> Config | None:
"""Return an initialized :class:`~pytest.Config`, parsing the specified args. """Return an initialized :class:`~pytest.Config`, parsing the specified args.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -185,7 +187,7 @@ def pytest_cmdline_parse(
def pytest_load_initial_conftests( def pytest_load_initial_conftests(
early_config: "Config", parser: "Parser", args: List[str] early_config: Config, parser: Parser, args: list[str],
) -> None: ) -> None:
"""Called to implement the loading of :ref:`initial conftest files """Called to implement the loading of :ref:`initial conftest files
<pluginorder>` ahead of command line option parsing. <pluginorder>` ahead of command line option parsing.
@ -202,7 +204,7 @@ def pytest_load_initial_conftests(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]: def pytest_cmdline_main(config: Config) -> ExitCode | int | None:
"""Called for performing the main command line action. """Called for performing the main command line action.
The default implementation will invoke the configure hooks and The default implementation will invoke the configure hooks and
@ -226,7 +228,7 @@ def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_collection(session: "Session") -> Optional[object]: def pytest_collection(session: Session) -> object | None:
"""Perform the collection phase for the given session. """Perform the collection phase for the given session.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -268,7 +270,7 @@ def pytest_collection(session: "Session") -> Optional[object]:
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
session: "Session", config: "Config", items: List["Item"] session: Session, config: Config, items: list[Item],
) -> None: ) -> None:
"""Called after collection has been performed. May filter or re-order """Called after collection has been performed. May filter or re-order
the items in-place. the items in-place.
@ -284,7 +286,7 @@ def pytest_collection_modifyitems(
""" """
def pytest_collection_finish(session: "Session") -> None: def pytest_collection_finish(session: Session) -> None:
"""Called after collection has been performed and modified. """Called after collection has been performed and modified.
:param session: The pytest session object. :param session: The pytest session object.
@ -298,8 +300,8 @@ def pytest_collection_finish(session: "Session") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_ignore_collect( def pytest_ignore_collect(
collection_path: Path, path: "LEGACY_PATH", config: "Config" collection_path: Path, path: LEGACY_PATH, config: Config,
) -> Optional[bool]: ) -> bool | None:
"""Return True to prevent considering this path for collection. """Return True to prevent considering this path for collection.
This hook is consulted for all files and directories prior to calling This hook is consulted for all files and directories prior to calling
@ -327,7 +329,7 @@ def pytest_ignore_collect(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_collect_directory(path: Path, parent: "Collector") -> "Optional[Collector]": def pytest_collect_directory(path: Path, parent: Collector) -> Optional[Collector]:
"""Create a :class:`~pytest.Collector` for the given directory, or None if """Create a :class:`~pytest.Collector` for the given directory, or None if
not relevant. not relevant.
@ -356,8 +358,8 @@ def pytest_collect_directory(path: Path, parent: "Collector") -> "Optional[Colle
def pytest_collect_file( def pytest_collect_file(
file_path: Path, path: "LEGACY_PATH", parent: "Collector" file_path: Path, path: LEGACY_PATH, parent: Collector,
) -> "Optional[Collector]": ) -> Optional[Collector]:
"""Create a :class:`~pytest.Collector` for the given path, or None if not relevant. """Create a :class:`~pytest.Collector` for the given path, or None if not relevant.
For best results, the returned collector should be a subclass of For best results, the returned collector should be a subclass of
@ -384,7 +386,7 @@ def pytest_collect_file(
# logging hooks for collection # logging hooks for collection
def pytest_collectstart(collector: "Collector") -> None: def pytest_collectstart(collector: Collector) -> None:
"""Collector starts collecting. """Collector starts collecting.
:param collector: :param collector:
@ -399,7 +401,7 @@ def pytest_collectstart(collector: "Collector") -> None:
""" """
def pytest_itemcollected(item: "Item") -> None: def pytest_itemcollected(item: Item) -> None:
"""We just collected a test item. """We just collected a test item.
:param item: :param item:
@ -413,7 +415,7 @@ def pytest_itemcollected(item: "Item") -> None:
""" """
def pytest_collectreport(report: "CollectReport") -> None: def pytest_collectreport(report: CollectReport) -> None:
"""Collector finished collecting. """Collector finished collecting.
:param report: :param report:
@ -428,7 +430,7 @@ def pytest_collectreport(report: "CollectReport") -> None:
""" """
def pytest_deselected(items: Sequence["Item"]) -> None: def pytest_deselected(items: Sequence[Item]) -> None:
"""Called for deselected test items, e.g. by keyword. """Called for deselected test items, e.g. by keyword.
May be called multiple times. May be called multiple times.
@ -444,7 +446,7 @@ def pytest_deselected(items: Sequence["Item"]) -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectReport]": def pytest_make_collect_report(collector: Collector) -> Optional[CollectReport]:
"""Perform :func:`collector.collect() <pytest.Collector.collect>` and return """Perform :func:`collector.collect() <pytest.Collector.collect>` and return
a :class:`~pytest.CollectReport`. a :class:`~pytest.CollectReport`.
@ -469,8 +471,8 @@ def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectRepor
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pycollect_makemodule( def pytest_pycollect_makemodule(
module_path: Path, path: "LEGACY_PATH", parent module_path: Path, path: LEGACY_PATH, parent,
) -> Optional["Module"]: ) -> Module | None:
"""Return a :class:`pytest.Module` collector or None for the given path. """Return a :class:`pytest.Module` collector or None for the given path.
This hook will be called for each matching test module path. This hook will be called for each matching test module path.
@ -499,8 +501,8 @@ def pytest_pycollect_makemodule(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: Union["Module", "Class"], name: str, obj: object collector: Module | Class, name: str, obj: object,
) -> Union[None, "Item", "Collector", List[Union["Item", "Collector"]]]: ) -> None | Item | Collector | list[Item | Collector]:
"""Return a custom item/collector for a Python object in a module, or None. """Return a custom item/collector for a Python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -524,7 +526,7 @@ def pytest_pycollect_makeitem(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]: def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
"""Call underlying test function. """Call underlying test function.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -541,7 +543,7 @@ def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
""" """
def pytest_generate_tests(metafunc: "Metafunc") -> None: def pytest_generate_tests(metafunc: Metafunc) -> None:
"""Generate (multiple) parametrized calls to a test function. """Generate (multiple) parametrized calls to a test function.
:param metafunc: :param metafunc:
@ -558,8 +560,8 @@ def pytest_generate_tests(metafunc: "Metafunc") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_make_parametrize_id( def pytest_make_parametrize_id(
config: "Config", val: object, argname: str config: Config, val: object, argname: str,
) -> Optional[str]: ) -> str | None:
"""Return a user-friendly string representation of the given ``val`` """Return a user-friendly string representation of the given ``val``
that will be used by @pytest.mark.parametrize calls, or None if the hook that will be used by @pytest.mark.parametrize calls, or None if the hook
doesn't know about ``val``. doesn't know about ``val``.
@ -585,7 +587,7 @@ def pytest_make_parametrize_id(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_runtestloop(session: "Session") -> Optional[object]: def pytest_runtestloop(session: Session) -> object | None:
"""Perform the main runtest loop (after collection finished). """Perform the main runtest loop (after collection finished).
The default hook implementation performs the runtest protocol for all items The default hook implementation performs the runtest protocol for all items
@ -612,8 +614,8 @@ def pytest_runtestloop(session: "Session") -> Optional[object]:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_runtest_protocol( def pytest_runtest_protocol(
item: "Item", nextitem: "Optional[Item]" item: Item, nextitem: Optional[Item],
) -> Optional[object]: ) -> object | None:
"""Perform the runtest protocol for a single test item. """Perform the runtest protocol for a single test item.
The default runtest protocol is this (see individual hooks for full details): The default runtest protocol is this (see individual hooks for full details):
@ -654,7 +656,7 @@ def pytest_runtest_protocol(
def pytest_runtest_logstart( def pytest_runtest_logstart(
nodeid: str, location: Tuple[str, Optional[int], str] nodeid: str, location: tuple[str, int | None, str],
) -> None: ) -> None:
"""Called at the start of running the runtest protocol for a single item. """Called at the start of running the runtest protocol for a single item.
@ -674,7 +676,7 @@ def pytest_runtest_logstart(
def pytest_runtest_logfinish( def pytest_runtest_logfinish(
nodeid: str, location: Tuple[str, Optional[int], str] nodeid: str, location: tuple[str, int | None, str],
) -> None: ) -> None:
"""Called at the end of running the runtest protocol for a single item. """Called at the end of running the runtest protocol for a single item.
@ -693,7 +695,7 @@ def pytest_runtest_logfinish(
""" """
def pytest_runtest_setup(item: "Item") -> None: def pytest_runtest_setup(item: Item) -> None:
"""Called to perform the setup phase for a test item. """Called to perform the setup phase for a test item.
The default implementation runs ``setup()`` on ``item`` and all of its The default implementation runs ``setup()`` on ``item`` and all of its
@ -712,7 +714,7 @@ def pytest_runtest_setup(item: "Item") -> None:
""" """
def pytest_runtest_call(item: "Item") -> None: def pytest_runtest_call(item: Item) -> None:
"""Called to run the test for test item (the call phase). """Called to run the test for test item (the call phase).
The default implementation calls ``item.runtest()``. The default implementation calls ``item.runtest()``.
@ -728,7 +730,7 @@ def pytest_runtest_call(item: "Item") -> None:
""" """
def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None: def pytest_runtest_teardown(item: Item, nextitem: Item | None) -> None:
"""Called to perform the teardown phase for a test item. """Called to perform the teardown phase for a test item.
The default implementation runs the finalizers and calls ``teardown()`` The default implementation runs the finalizers and calls ``teardown()``
@ -754,8 +756,8 @@ def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_runtest_makereport( def pytest_runtest_makereport(
item: "Item", call: "CallInfo[None]" item: Item, call: CallInfo[None],
) -> Optional["TestReport"]: ) -> TestReport | None:
"""Called to create a :class:`~pytest.TestReport` for each of """Called to create a :class:`~pytest.TestReport` for each of
the setup, call and teardown runtest phases of a test item. the setup, call and teardown runtest phases of a test item.
@ -774,7 +776,7 @@ def pytest_runtest_makereport(
""" """
def pytest_runtest_logreport(report: "TestReport") -> None: def pytest_runtest_logreport(report: TestReport) -> None:
"""Process the :class:`~pytest.TestReport` produced for each """Process the :class:`~pytest.TestReport` produced for each
of the setup, call and teardown runtest phases of an item. of the setup, call and teardown runtest phases of an item.
@ -790,9 +792,9 @@ def pytest_runtest_logreport(report: "TestReport") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_to_serializable( def pytest_report_to_serializable(
config: "Config", config: Config,
report: Union["CollectReport", "TestReport"], report: CollectReport | TestReport,
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
"""Serialize the given report object into a data structure suitable for """Serialize the given report object into a data structure suitable for
sending over the wire, e.g. converted to JSON. sending over the wire, e.g. converted to JSON.
@ -809,9 +811,9 @@ def pytest_report_to_serializable(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_from_serializable( def pytest_report_from_serializable(
config: "Config", config: Config,
data: Dict[str, Any], data: dict[str, Any],
) -> Optional[Union["CollectReport", "TestReport"]]: ) -> CollectReport | TestReport | None:
"""Restore a report object previously serialized with """Restore a report object previously serialized with
:hook:`pytest_report_to_serializable`. :hook:`pytest_report_to_serializable`.
@ -832,8 +834,8 @@ def pytest_report_from_serializable(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_fixture_setup( def pytest_fixture_setup(
fixturedef: "FixtureDef[Any]", request: "SubRequest" fixturedef: FixtureDef[Any], request: SubRequest,
) -> Optional[object]: ) -> object | None:
"""Perform fixture setup execution. """Perform fixture setup execution.
:param fixturdef: :param fixturdef:
@ -860,7 +862,7 @@ def pytest_fixture_setup(
def pytest_fixture_post_finalizer( def pytest_fixture_post_finalizer(
fixturedef: "FixtureDef[Any]", request: "SubRequest" fixturedef: FixtureDef[Any], request: SubRequest,
) -> None: ) -> None:
"""Called after fixture teardown, but before the cache is cleared, so """Called after fixture teardown, but before the cache is cleared, so
the fixture result ``fixturedef.cached_result`` is still available (not the fixture result ``fixturedef.cached_result`` is still available (not
@ -885,7 +887,7 @@ def pytest_fixture_post_finalizer(
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def pytest_sessionstart(session: "Session") -> None: def pytest_sessionstart(session: Session) -> None:
"""Called after the ``Session`` object has been created and before performing collection """Called after the ``Session`` object has been created and before performing collection
and entering the run test loop. and entering the run test loop.
@ -899,8 +901,8 @@ def pytest_sessionstart(session: "Session") -> None:
def pytest_sessionfinish( def pytest_sessionfinish(
session: "Session", session: Session,
exitstatus: Union[int, "ExitCode"], exitstatus: int | ExitCode,
) -> None: ) -> None:
"""Called after whole test run finished, right before returning the exit status to the system. """Called after whole test run finished, right before returning the exit status to the system.
@ -914,7 +916,7 @@ def pytest_sessionfinish(
""" """
def pytest_unconfigure(config: "Config") -> None: def pytest_unconfigure(config: Config) -> None:
"""Called before test process is exited. """Called before test process is exited.
:param config: The pytest config object. :param config: The pytest config object.
@ -932,8 +934,8 @@ def pytest_unconfigure(config: "Config") -> None:
def pytest_assertrepr_compare( def pytest_assertrepr_compare(
config: "Config", op: str, left: object, right: object config: Config, op: str, left: object, right: object,
) -> Optional[List[str]]: ) -> list[str] | None:
"""Return explanation for comparisons in failing assert expressions. """Return explanation for comparisons in failing assert expressions.
Return None for no custom explanation, otherwise return a list Return None for no custom explanation, otherwise return a list
@ -954,7 +956,7 @@ def pytest_assertrepr_compare(
""" """
def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> None: def pytest_assertion_pass(item: Item, lineno: int, orig: str, expl: str) -> None:
"""Called whenever an assertion passes. """Called whenever an assertion passes.
.. versionadded:: 5.0 .. versionadded:: 5.0
@ -994,8 +996,8 @@ def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> No
def pytest_report_header( # type:ignore[empty-body] def pytest_report_header( # type:ignore[empty-body]
config: "Config", start_path: Path, startdir: "LEGACY_PATH" config: Config, start_path: Path, startdir: LEGACY_PATH,
) -> Union[str, List[str]]: ) -> str | list[str]:
"""Return a string or list of strings to be displayed as header info for terminal reporting. """Return a string or list of strings to be displayed as header info for terminal reporting.
:param config: The pytest config object. :param config: The pytest config object.
@ -1022,11 +1024,11 @@ def pytest_report_header( # type:ignore[empty-body]
def pytest_report_collectionfinish( # type:ignore[empty-body] def pytest_report_collectionfinish( # type:ignore[empty-body]
config: "Config", config: Config,
start_path: Path, start_path: Path,
startdir: "LEGACY_PATH", startdir: LEGACY_PATH,
items: Sequence["Item"], items: Sequence[Item],
) -> Union[str, List[str]]: ) -> str | list[str]:
"""Return a string or list of strings to be displayed after collection """Return a string or list of strings to be displayed after collection
has finished successfully. has finished successfully.
@ -1060,8 +1062,8 @@ def pytest_report_collectionfinish( # type:ignore[empty-body]
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_teststatus( # type:ignore[empty-body] def pytest_report_teststatus( # type:ignore[empty-body]
report: Union["CollectReport", "TestReport"], config: "Config" report: CollectReport | TestReport, config: Config,
) -> "TestShortLogReport | Tuple[str, str, Union[str, Tuple[str, Mapping[str, bool]]]]": ) -> TestShortLogReport | Tuple[str, str, Union[str, Tuple[str, Mapping[str, bool]]]]:
"""Return result-category, shortletter and verbose word for status """Return result-category, shortletter and verbose word for status
reporting. reporting.
@ -1092,9 +1094,9 @@ def pytest_report_teststatus( # type:ignore[empty-body]
def pytest_terminal_summary( def pytest_terminal_summary(
terminalreporter: "TerminalReporter", terminalreporter: TerminalReporter,
exitstatus: "ExitCode", exitstatus: ExitCode,
config: "Config", config: Config,
) -> None: ) -> None:
"""Add a section to terminal summary reporting. """Add a section to terminal summary reporting.
@ -1114,10 +1116,10 @@ def pytest_terminal_summary(
@hookspec(historic=True) @hookspec(historic=True)
def pytest_warning_recorded( def pytest_warning_recorded(
warning_message: "warnings.WarningMessage", warning_message: warnings.WarningMessage,
when: "Literal['config', 'collect', 'runtest']", when: Literal['config', 'collect', 'runtest'],
nodeid: str, nodeid: str,
location: Optional[Tuple[str, int, str]], location: tuple[str, int, str] | None,
) -> None: ) -> None:
"""Process a warning captured by the internal pytest warnings plugin. """Process a warning captured by the internal pytest warnings plugin.
@ -1158,8 +1160,8 @@ def pytest_warning_recorded(
def pytest_markeval_namespace( # type:ignore[empty-body] def pytest_markeval_namespace( # type:ignore[empty-body]
config: "Config", config: Config,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Called when constructing the globals dictionary used for """Called when constructing the globals dictionary used for
evaluating string conditions in xfail/skipif markers. evaluating string conditions in xfail/skipif markers.
@ -1187,9 +1189,9 @@ def pytest_markeval_namespace( # type:ignore[empty-body]
def pytest_internalerror( def pytest_internalerror(
excrepr: "ExceptionRepr", excrepr: ExceptionRepr,
excinfo: "ExceptionInfo[BaseException]", excinfo: ExceptionInfo[BaseException],
) -> Optional[bool]: ) -> bool | None:
"""Called for internal errors. """Called for internal errors.
Return True to suppress the fallback handling of printing an Return True to suppress the fallback handling of printing an
@ -1206,7 +1208,7 @@ def pytest_internalerror(
def pytest_keyboard_interrupt( def pytest_keyboard_interrupt(
excinfo: "ExceptionInfo[Union[KeyboardInterrupt, Exit]]", excinfo: ExceptionInfo[Union[KeyboardInterrupt, Exit]],
) -> None: ) -> None:
"""Called for keyboard interrupt. """Called for keyboard interrupt.
@ -1220,9 +1222,9 @@ def pytest_keyboard_interrupt(
def pytest_exception_interact( def pytest_exception_interact(
node: Union["Item", "Collector"], node: Item | Collector,
call: "CallInfo[Any]", call: CallInfo[Any],
report: Union["CollectReport", "TestReport"], report: CollectReport | TestReport,
) -> None: ) -> None:
"""Called when an exception was raised which can potentially be """Called when an exception was raised which can potentially be
interactively handled. interactively handled.
@ -1251,7 +1253,7 @@ def pytest_exception_interact(
""" """
def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None: def pytest_enter_pdb(config: Config, pdb: pdb.Pdb) -> None:
"""Called upon pdb.set_trace(). """Called upon pdb.set_trace().
Can be used by plugins to take special action just before the python Can be used by plugins to take special action just before the python
@ -1267,7 +1269,7 @@ def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
""" """
def pytest_leave_pdb(config: "Config", pdb: "pdb.Pdb") -> None: def pytest_leave_pdb(config: Config, pdb: pdb.Pdb) -> None:
"""Called when leaving pdb (e.g. with continue after pdb.set_trace()). """Called when leaving pdb (e.g. with continue after pdb.set_trace()).
Can be used by plugins to take special action just after the python Can be used by plugins to take special action just after the python

View file

@ -7,11 +7,14 @@ Based on initial code from Ross Lawley.
Output conforms to Output conforms to
https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd
""" """
from datetime import datetime from __future__ import annotations
import functools import functools
import os import os
import platform import platform
import re import re
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import List from typing import List
@ -19,8 +22,8 @@ from typing import Match
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
import xml.etree.ElementTree as ET
import pytest
from _pytest import nodes from _pytest import nodes
from _pytest import timing from _pytest import timing
from _pytest._code.code import ExceptionRepr from _pytest._code.code import ExceptionRepr
@ -32,10 +35,9 @@ from _pytest.fixtures import FixtureRequest
from _pytest.reports import TestReport from _pytest.reports import TestReport
from _pytest.stash import StashKey from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
import pytest
xml_key = StashKey["LogXML"]() xml_key = StashKey['LogXML']()
def bin_xml_escape(arg: object) -> str: def bin_xml_escape(arg: object) -> str:
@ -52,15 +54,15 @@ def bin_xml_escape(arg: object) -> str:
def repl(matchobj: Match[str]) -> str: def repl(matchobj: Match[str]) -> str:
i = ord(matchobj.group()) i = ord(matchobj.group())
if i <= 0xFF: if i <= 0xFF:
return "#x%02X" % i return '#x%02X' % i
else: else:
return "#x%04X" % i return '#x%04X' % i
# The spec range of valid chars is: # The spec range of valid chars is:
# Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF] # Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
# For an unknown(?) reason, we disallow #x7F (DEL) as well. # For an unknown(?) reason, we disallow #x7F (DEL) as well.
illegal_xml_re = ( illegal_xml_re = (
"[^\u0009\u000A\u000D\u0020-\u007E\u0080-\uD7FF\uE000-\uFFFD\u10000-\u10FFFF]" '[^\u0009\u000A\u000D\u0020-\u007E\u0080-\uD7FF\uE000-\uFFFD\u10000-\u10FFFF]'
) )
return re.sub(illegal_xml_re, repl, str(arg)) return re.sub(illegal_xml_re, repl, str(arg))
@ -76,27 +78,27 @@ def merge_family(left, right) -> None:
families = {} families = {}
families["_base"] = {"testcase": ["classname", "name"]} families['_base'] = {'testcase': ['classname', 'name']}
families["_base_legacy"] = {"testcase": ["file", "line", "url"]} families['_base_legacy'] = {'testcase': ['file', 'line', 'url']}
# xUnit 1.x inherits legacy attributes. # xUnit 1.x inherits legacy attributes.
families["xunit1"] = families["_base"].copy() families['xunit1'] = families['_base'].copy()
merge_family(families["xunit1"], families["_base_legacy"]) merge_family(families['xunit1'], families['_base_legacy'])
# xUnit 2.x uses strict base attributes. # xUnit 2.x uses strict base attributes.
families["xunit2"] = families["_base"] families['xunit2'] = families['_base']
class _NodeReporter: class _NodeReporter:
def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None: def __init__(self, nodeid: str | TestReport, xml: LogXML) -> None:
self.id = nodeid self.id = nodeid
self.xml = xml self.xml = xml
self.add_stats = self.xml.add_stats self.add_stats = self.xml.add_stats
self.family = self.xml.family self.family = self.xml.family
self.duration = 0.0 self.duration = 0.0
self.properties: List[Tuple[str, str]] = [] self.properties: list[tuple[str, str]] = []
self.nodes: List[ET.Element] = [] self.nodes: list[ET.Element] = []
self.attrs: Dict[str, str] = {} self.attrs: dict[str, str] = {}
def append(self, node: ET.Element) -> None: def append(self, node: ET.Element) -> None:
self.xml.add_stats(node.tag) self.xml.add_stats(node.tag)
@ -108,12 +110,12 @@ class _NodeReporter:
def add_attribute(self, name: str, value: object) -> None: def add_attribute(self, name: str, value: object) -> None:
self.attrs[str(name)] = bin_xml_escape(value) self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self) -> Optional[ET.Element]: def make_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any.""" """Return a Junit node containing custom properties, if any."""
if self.properties: if self.properties:
properties = ET.Element("properties") properties = ET.Element('properties')
for name, value in self.properties: for name, value in self.properties:
properties.append(ET.Element("property", name=name, value=value)) properties.append(ET.Element('property', name=name, value=value))
return properties return properties
return None return None
@ -123,39 +125,39 @@ class _NodeReporter:
classnames = names[:-1] classnames = names[:-1]
if self.xml.prefix: if self.xml.prefix:
classnames.insert(0, self.xml.prefix) classnames.insert(0, self.xml.prefix)
attrs: Dict[str, str] = { attrs: dict[str, str] = {
"classname": ".".join(classnames), 'classname': '.'.join(classnames),
"name": bin_xml_escape(names[-1]), 'name': bin_xml_escape(names[-1]),
"file": testreport.location[0], 'file': testreport.location[0],
} }
if testreport.location[1] is not None: if testreport.location[1] is not None:
attrs["line"] = str(testreport.location[1]) attrs['line'] = str(testreport.location[1])
if hasattr(testreport, "url"): if hasattr(testreport, 'url'):
attrs["url"] = testreport.url attrs['url'] = testreport.url
self.attrs = attrs self.attrs = attrs
self.attrs.update(existing_attrs) # Restore any user-defined attributes. self.attrs.update(existing_attrs) # Restore any user-defined attributes.
# Preserve legacy testcase behavior. # Preserve legacy testcase behavior.
if self.family == "xunit1": if self.family == 'xunit1':
return return
# Filter out attributes not permitted by this test family. # Filter out attributes not permitted by this test family.
# Including custom attributes because they are not valid here. # Including custom attributes because they are not valid here.
temp_attrs = {} temp_attrs = {}
for key in self.attrs.keys(): for key in self.attrs.keys():
if key in families[self.family]["testcase"]: if key in families[self.family]['testcase']:
temp_attrs[key] = self.attrs[key] temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs self.attrs = temp_attrs
def to_xml(self) -> ET.Element: def to_xml(self) -> ET.Element:
testcase = ET.Element("testcase", self.attrs, time="%.3f" % self.duration) testcase = ET.Element('testcase', self.attrs, time='%.3f' % self.duration)
properties = self.make_properties_node() properties = self.make_properties_node()
if properties is not None: if properties is not None:
testcase.append(properties) testcase.append(properties)
testcase.extend(self.nodes) testcase.extend(self.nodes)
return testcase return testcase
def _add_simple(self, tag: str, message: str, data: Optional[str] = None) -> None: def _add_simple(self, tag: str, message: str, data: str | None = None) -> None:
node = ET.Element(tag, message=message) node = ET.Element(tag, message=message)
node.text = bin_xml_escape(data) node.text = bin_xml_escape(data)
self.append(node) self.append(node)
@ -167,24 +169,24 @@ class _NodeReporter:
content_out = report.capstdout content_out = report.capstdout
content_log = report.caplog content_log = report.caplog
content_err = report.capstderr content_err = report.capstderr
if self.xml.logging == "no": if self.xml.logging == 'no':
return return
content_all = "" content_all = ''
if self.xml.logging in ["log", "all"]: if self.xml.logging in ['log', 'all']:
content_all = self._prepare_content(content_log, " Captured Log ") content_all = self._prepare_content(content_log, ' Captured Log ')
if self.xml.logging in ["system-out", "out-err", "all"]: if self.xml.logging in ['system-out', 'out-err', 'all']:
content_all += self._prepare_content(content_out, " Captured Out ") content_all += self._prepare_content(content_out, ' Captured Out ')
self._write_content(report, content_all, "system-out") self._write_content(report, content_all, 'system-out')
content_all = "" content_all = ''
if self.xml.logging in ["system-err", "out-err", "all"]: if self.xml.logging in ['system-err', 'out-err', 'all']:
content_all += self._prepare_content(content_err, " Captured Err ") content_all += self._prepare_content(content_err, ' Captured Err ')
self._write_content(report, content_all, "system-err") self._write_content(report, content_all, 'system-err')
content_all = "" content_all = ''
if content_all: if content_all:
self._write_content(report, content_all, "system-out") self._write_content(report, content_all, 'system-out')
def _prepare_content(self, content: str, header: str) -> str: def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""]) return '\n'.join([header.center(80, '-'), content, ''])
def _write_content(self, report: TestReport, content: str, jheader: str) -> None: def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = ET.Element(jheader) tag = ET.Element(jheader)
@ -192,65 +194,65 @@ class _NodeReporter:
self.append(tag) self.append(tag)
def append_pass(self, report: TestReport) -> None: def append_pass(self, report: TestReport) -> None:
self.add_stats("passed") self.add_stats('passed')
def append_failure(self, report: TestReport) -> None: def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline) # msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"): if hasattr(report, 'wasxfail'):
self._add_simple("skipped", "xfail-marked test passes unexpectedly") self._add_simple('skipped', 'xfail-marked test passes unexpectedly')
else: else:
assert report.longrepr is not None assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr( reprcrash: ReprFileLocation | None = getattr(
report.longrepr, "reprcrash", None report.longrepr, 'reprcrash', None,
) )
if reprcrash is not None: if reprcrash is not None:
message = reprcrash.message message = reprcrash.message
else: else:
message = str(report.longrepr) message = str(report.longrepr)
message = bin_xml_escape(message) message = bin_xml_escape(message)
self._add_simple("failure", message, str(report.longrepr)) self._add_simple('failure', message, str(report.longrepr))
def append_collect_error(self, report: TestReport) -> None: def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline) # msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None assert report.longrepr is not None
self._add_simple("error", "collection failure", str(report.longrepr)) self._add_simple('error', 'collection failure', str(report.longrepr))
def append_collect_skipped(self, report: TestReport) -> None: def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple("skipped", "collection skipped", str(report.longrepr)) self._add_simple('skipped', 'collection skipped', str(report.longrepr))
def append_error(self, report: TestReport) -> None: def append_error(self, report: TestReport) -> None:
assert report.longrepr is not None assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr( reprcrash: ReprFileLocation | None = getattr(
report.longrepr, "reprcrash", None report.longrepr, 'reprcrash', None,
) )
if reprcrash is not None: if reprcrash is not None:
reason = reprcrash.message reason = reprcrash.message
else: else:
reason = str(report.longrepr) reason = str(report.longrepr)
if report.when == "teardown": if report.when == 'teardown':
msg = f'failed on teardown with "{reason}"' msg = f'failed on teardown with "{reason}"'
else: else:
msg = f'failed on setup with "{reason}"' msg = f'failed on setup with "{reason}"'
self._add_simple("error", bin_xml_escape(msg), str(report.longrepr)) self._add_simple('error', bin_xml_escape(msg), str(report.longrepr))
def append_skipped(self, report: TestReport) -> None: def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"): if hasattr(report, 'wasxfail'):
xfailreason = report.wasxfail xfailreason = report.wasxfail
if xfailreason.startswith("reason: "): if xfailreason.startswith('reason: '):
xfailreason = xfailreason[8:] xfailreason = xfailreason[8:]
xfailreason = bin_xml_escape(xfailreason) xfailreason = bin_xml_escape(xfailreason)
skipped = ET.Element("skipped", type="pytest.xfail", message=xfailreason) skipped = ET.Element('skipped', type='pytest.xfail', message=xfailreason)
self.append(skipped) self.append(skipped)
else: else:
assert isinstance(report.longrepr, tuple) assert isinstance(report.longrepr, tuple)
filename, lineno, skipreason = report.longrepr filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "): if skipreason.startswith('Skipped: '):
skipreason = skipreason[9:] skipreason = skipreason[9:]
details = f"{filename}:{lineno}: {skipreason}" details = f'{filename}:{lineno}: {skipreason}'
skipped = ET.Element( skipped = ET.Element(
"skipped", type="pytest.skip", message=bin_xml_escape(skipreason) 'skipped', type='pytest.skip', message=bin_xml_escape(skipreason),
) )
skipped.text = bin_xml_escape(details) skipped.text = bin_xml_escape(details)
self.append(skipped) self.append(skipped)
@ -265,17 +267,17 @@ class _NodeReporter:
def _warn_incompatibility_with_xunit2( def _warn_incompatibility_with_xunit2(
request: FixtureRequest, fixture_name: str request: FixtureRequest, fixture_name: str,
) -> None: ) -> None:
"""Emit a PytestWarning about the given fixture being incompatible with newer xunit revisions.""" """Emit a PytestWarning about the given fixture being incompatible with newer xunit revisions."""
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
xml = request.config.stash.get(xml_key, None) xml = request.config.stash.get(xml_key, None)
if xml is not None and xml.family not in ("xunit1", "legacy"): if xml is not None and xml.family not in ('xunit1', 'legacy'):
request.node.warn( request.node.warn(
PytestWarning( PytestWarning(
f"{fixture_name} is incompatible with junit_family '{xml.family}' (use 'legacy' or 'xunit1')" f"{fixture_name} is incompatible with junit_family '{xml.family}' (use 'legacy' or 'xunit1')",
) ),
) )
@ -294,7 +296,7 @@ def record_property(request: FixtureRequest) -> Callable[[str, object], None]:
def test_function(record_property): def test_function(record_property):
record_property("example_key", 1) record_property("example_key", 1)
""" """
_warn_incompatibility_with_xunit2(request, "record_property") _warn_incompatibility_with_xunit2(request, 'record_property')
def append_property(name: str, value: object) -> None: def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value)) request.node.user_properties.append((name, value))
@ -312,10 +314,10 @@ def record_xml_attribute(request: FixtureRequest) -> Callable[[str, object], Non
from _pytest.warning_types import PytestExperimentalApiWarning from _pytest.warning_types import PytestExperimentalApiWarning
request.node.warn( request.node.warn(
PytestExperimentalApiWarning("record_xml_attribute is an experimental feature") PytestExperimentalApiWarning('record_xml_attribute is an experimental feature'),
) )
_warn_incompatibility_with_xunit2(request, "record_xml_attribute") _warn_incompatibility_with_xunit2(request, 'record_xml_attribute')
# Declare noop # Declare noop
def add_attr_noop(name: str, value: object) -> None: def add_attr_noop(name: str, value: object) -> None:
@ -336,11 +338,11 @@ def _check_record_param_type(param: str, v: str) -> None:
type.""" type."""
__tracebackhide__ = True __tracebackhide__ = True
if not isinstance(v, str): if not isinstance(v, str):
msg = "{param} parameter needs to be a string, but {g} given" # type: ignore[unreachable] msg = '{param} parameter needs to be a string, but {g} given' # type: ignore[unreachable]
raise TypeError(msg.format(param=param, g=type(v).__name__)) raise TypeError(msg.format(param=param, g=type(v).__name__))
@pytest.fixture(scope="session") @pytest.fixture(scope='session')
def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object], None]: def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object], None]:
"""Record a new ``<property>`` tag as child of the root ``<testsuite>``. """Record a new ``<property>`` tag as child of the root ``<testsuite>``.
@ -371,7 +373,7 @@ def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object]
def record_func(name: str, value: object) -> None: def record_func(name: str, value: object) -> None:
"""No-op function in case --junit-xml was not passed in the command-line.""" """No-op function in case --junit-xml was not passed in the command-line."""
__tracebackhide__ = True __tracebackhide__ = True
_check_record_param_type("name", name) _check_record_param_type('name', name)
xml = request.config.stash.get(xml_key, None) xml = request.config.stash.get(xml_key, None)
if xml is not None: if xml is not None:
@ -380,65 +382,65 @@ def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object]
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting") group = parser.getgroup('terminal reporting')
group.addoption( group.addoption(
"--junitxml", '--junitxml',
"--junit-xml", '--junit-xml',
action="store", action='store',
dest="xmlpath", dest='xmlpath',
metavar="path", metavar='path',
type=functools.partial(filename_arg, optname="--junitxml"), type=functools.partial(filename_arg, optname='--junitxml'),
default=None, default=None,
help="Create junit-xml style report file at given path", help='Create junit-xml style report file at given path',
) )
group.addoption( group.addoption(
"--junitprefix", '--junitprefix',
"--junit-prefix", '--junit-prefix',
action="store", action='store',
metavar="str", metavar='str',
default=None, default=None,
help="Prepend prefix to classnames in junit-xml output", help='Prepend prefix to classnames in junit-xml output',
) )
parser.addini( parser.addini(
"junit_suite_name", "Test suite name for JUnit report", default="pytest" 'junit_suite_name', 'Test suite name for JUnit report', default='pytest',
) )
parser.addini( parser.addini(
"junit_logging", 'junit_logging',
"Write captured log messages to JUnit report: " 'Write captured log messages to JUnit report: '
"one of no|log|system-out|system-err|out-err|all", 'one of no|log|system-out|system-err|out-err|all',
default="no", default='no',
) )
parser.addini( parser.addini(
"junit_log_passing_tests", 'junit_log_passing_tests',
"Capture log information for passing tests to JUnit report: ", 'Capture log information for passing tests to JUnit report: ',
type="bool", type='bool',
default=True, default=True,
) )
parser.addini( parser.addini(
"junit_duration_report", 'junit_duration_report',
"Duration time to report: one of total|call", 'Duration time to report: one of total|call',
default="total", default='total',
) # choices=['total', 'call']) ) # choices=['total', 'call'])
parser.addini( parser.addini(
"junit_family", 'junit_family',
"Emit XML for schema: one of legacy|xunit1|xunit2", 'Emit XML for schema: one of legacy|xunit1|xunit2',
default="xunit2", default='xunit2',
) )
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
xmlpath = config.option.xmlpath xmlpath = config.option.xmlpath
# Prevent opening xmllog on worker nodes (xdist). # Prevent opening xmllog on worker nodes (xdist).
if xmlpath and not hasattr(config, "workerinput"): if xmlpath and not hasattr(config, 'workerinput'):
junit_family = config.getini("junit_family") junit_family = config.getini('junit_family')
config.stash[xml_key] = LogXML( config.stash[xml_key] = LogXML(
xmlpath, xmlpath,
config.option.junitprefix, config.option.junitprefix,
config.getini("junit_suite_name"), config.getini('junit_suite_name'),
config.getini("junit_logging"), config.getini('junit_logging'),
config.getini("junit_duration_report"), config.getini('junit_duration_report'),
junit_family, junit_family,
config.getini("junit_log_passing_tests"), config.getini('junit_log_passing_tests'),
) )
config.pluginmanager.register(config.stash[xml_key]) config.pluginmanager.register(config.stash[xml_key])
@ -450,12 +452,12 @@ def pytest_unconfigure(config: Config) -> None:
config.pluginmanager.unregister(xml) config.pluginmanager.unregister(xml)
def mangle_test_address(address: str) -> List[str]: def mangle_test_address(address: str) -> list[str]:
path, possible_open_bracket, params = address.partition("[") path, possible_open_bracket, params = address.partition('[')
names = path.split("::") names = path.split('::')
# Convert file path to dotted path. # Convert file path to dotted path.
names[0] = names[0].replace(nodes.SEP, ".") names[0] = names[0].replace(nodes.SEP, '.')
names[0] = re.sub(r"\.py$", "", names[0]) names[0] = re.sub(r'\.py$', '', names[0])
# Put any params back. # Put any params back.
names[-1] += possible_open_bracket + params names[-1] += possible_open_bracket + params
return names return names
@ -465,11 +467,11 @@ class LogXML:
def __init__( def __init__(
self, self,
logfile, logfile,
prefix: Optional[str], prefix: str | None,
suite_name: str = "pytest", suite_name: str = 'pytest',
logging: str = "no", logging: str = 'no',
report_duration: str = "total", report_duration: str = 'total',
family="xunit1", family='xunit1',
log_passing_tests: bool = True, log_passing_tests: bool = True,
) -> None: ) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile)) logfile = os.path.expanduser(os.path.expandvars(logfile))
@ -480,27 +482,27 @@ class LogXML:
self.log_passing_tests = log_passing_tests self.log_passing_tests = log_passing_tests
self.report_duration = report_duration self.report_duration = report_duration
self.family = family self.family = family
self.stats: Dict[str, int] = dict.fromkeys( self.stats: dict[str, int] = dict.fromkeys(
["error", "passed", "failure", "skipped"], 0 ['error', 'passed', 'failure', 'skipped'], 0,
) )
self.node_reporters: Dict[ self.node_reporters: dict[
Tuple[Union[str, TestReport], object], _NodeReporter tuple[str | TestReport, object], _NodeReporter,
] = {} ] = {}
self.node_reporters_ordered: List[_NodeReporter] = [] self.node_reporters_ordered: list[_NodeReporter] = []
self.global_properties: List[Tuple[str, str]] = [] self.global_properties: list[tuple[str, str]] = []
# List of reports that failed on call but teardown is pending. # List of reports that failed on call but teardown is pending.
self.open_reports: List[TestReport] = [] self.open_reports: list[TestReport] = []
self.cnt_double_fail_tests = 0 self.cnt_double_fail_tests = 0
# Replaces convenience family with real family. # Replaces convenience family with real family.
if self.family == "legacy": if self.family == 'legacy':
self.family = "xunit1" self.family = 'xunit1'
def finalize(self, report: TestReport) -> None: def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report) nodeid = getattr(report, 'nodeid', report)
# Local hack to handle xdist report order. # Local hack to handle xdist report order.
workernode = getattr(report, "node", None) workernode = getattr(report, 'node', None)
reporter = self.node_reporters.pop((nodeid, workernode)) reporter = self.node_reporters.pop((nodeid, workernode))
for propname, propvalue in report.user_properties: for propname, propvalue in report.user_properties:
@ -509,10 +511,10 @@ class LogXML:
if reporter is not None: if reporter is not None:
reporter.finalize() reporter.finalize()
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter: def node_reporter(self, report: TestReport | str) -> _NodeReporter:
nodeid: Union[str, TestReport] = getattr(report, "nodeid", report) nodeid: str | TestReport = getattr(report, 'nodeid', report)
# Local hack to handle xdist report order. # Local hack to handle xdist report order.
workernode = getattr(report, "node", None) workernode = getattr(report, 'node', None)
key = nodeid, workernode key = nodeid, workernode
@ -561,22 +563,22 @@ class LogXML:
""" """
close_report = None close_report = None
if report.passed: if report.passed:
if report.when == "call": # ignore setup/teardown if report.when == 'call': # ignore setup/teardown
reporter = self._opentestcase(report) reporter = self._opentestcase(report)
reporter.append_pass(report) reporter.append_pass(report)
elif report.failed: elif report.failed:
if report.when == "teardown": if report.when == 'teardown':
# The following vars are needed when xdist plugin is used. # The following vars are needed when xdist plugin is used.
report_wid = getattr(report, "worker_id", None) report_wid = getattr(report, 'worker_id', None)
report_ii = getattr(report, "item_index", None) report_ii = getattr(report, 'item_index', None)
close_report = next( close_report = next(
( (
rep rep
for rep in self.open_reports for rep in self.open_reports
if ( if (
rep.nodeid == report.nodeid rep.nodeid == report.nodeid and
and getattr(rep, "item_index", None) == report_ii getattr(rep, 'item_index', None) == report_ii and
and getattr(rep, "worker_id", None) == report_wid getattr(rep, 'worker_id', None) == report_wid
) )
), ),
None, None,
@ -588,7 +590,7 @@ class LogXML:
self.finalize(close_report) self.finalize(close_report)
self.cnt_double_fail_tests += 1 self.cnt_double_fail_tests += 1
reporter = self._opentestcase(report) reporter = self._opentestcase(report)
if report.when == "call": if report.when == 'call':
reporter.append_failure(report) reporter.append_failure(report)
self.open_reports.append(report) self.open_reports.append(report)
if not self.log_passing_tests: if not self.log_passing_tests:
@ -599,21 +601,21 @@ class LogXML:
reporter = self._opentestcase(report) reporter = self._opentestcase(report)
reporter.append_skipped(report) reporter.append_skipped(report)
self.update_testcase_duration(report) self.update_testcase_duration(report)
if report.when == "teardown": if report.when == 'teardown':
reporter = self._opentestcase(report) reporter = self._opentestcase(report)
reporter.write_captured_output(report) reporter.write_captured_output(report)
self.finalize(report) self.finalize(report)
report_wid = getattr(report, "worker_id", None) report_wid = getattr(report, 'worker_id', None)
report_ii = getattr(report, "item_index", None) report_ii = getattr(report, 'item_index', None)
close_report = next( close_report = next(
( (
rep rep
for rep in self.open_reports for rep in self.open_reports
if ( if (
rep.nodeid == report.nodeid rep.nodeid == report.nodeid and
and getattr(rep, "item_index", None) == report_ii getattr(rep, 'item_index', None) == report_ii and
and getattr(rep, "worker_id", None) == report_wid getattr(rep, 'worker_id', None) == report_wid
) )
), ),
None, None,
@ -624,9 +626,9 @@ class LogXML:
def update_testcase_duration(self, report: TestReport) -> None: def update_testcase_duration(self, report: TestReport) -> None:
"""Accumulate total duration for nodeid from given report and update """Accumulate total duration for nodeid from given report and update
the Junit.testcase with the new total if already created.""" the Junit.testcase with the new total if already created."""
if self.report_duration in {"total", report.when}: if self.report_duration in {'total', report.when}:
reporter = self.node_reporter(report) reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0) reporter.duration += getattr(report, 'duration', 0.0)
def pytest_collectreport(self, report: TestReport) -> None: def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed: if not report.passed:
@ -637,9 +639,9 @@ class LogXML:
reporter.append_collect_skipped(report) reporter.append_collect_skipped(report)
def pytest_internalerror(self, excrepr: ExceptionRepr) -> None: def pytest_internalerror(self, excrepr: ExceptionRepr) -> None:
reporter = self.node_reporter("internal") reporter = self.node_reporter('internal')
reporter.attrs.update(classname="pytest", name="internal") reporter.attrs.update(classname='pytest', name='internal')
reporter._add_simple("error", "internal error", str(excrepr)) reporter._add_simple('error', 'internal error', str(excrepr))
def pytest_sessionstart(self) -> None: def pytest_sessionstart(self) -> None:
self.suite_start_time = timing.time() self.suite_start_time = timing.time()
@ -649,27 +651,27 @@ class LogXML:
# exist_ok avoids filesystem race conditions between checking path existence and requesting creation # exist_ok avoids filesystem race conditions between checking path existence and requesting creation
os.makedirs(dirname, exist_ok=True) os.makedirs(dirname, exist_ok=True)
with open(self.logfile, "w", encoding="utf-8") as logfile: with open(self.logfile, 'w', encoding='utf-8') as logfile:
suite_stop_time = timing.time() suite_stop_time = timing.time()
suite_time_delta = suite_stop_time - self.suite_start_time suite_time_delta = suite_stop_time - self.suite_start_time
numtests = ( numtests = (
self.stats["passed"] self.stats['passed'] +
+ self.stats["failure"] self.stats['failure'] +
+ self.stats["skipped"] self.stats['skipped'] +
+ self.stats["error"] self.stats['error'] -
- self.cnt_double_fail_tests self.cnt_double_fail_tests
) )
logfile.write('<?xml version="1.0" encoding="utf-8"?>') logfile.write('<?xml version="1.0" encoding="utf-8"?>')
suite_node = ET.Element( suite_node = ET.Element(
"testsuite", 'testsuite',
name=self.suite_name, name=self.suite_name,
errors=str(self.stats["error"]), errors=str(self.stats['error']),
failures=str(self.stats["failure"]), failures=str(self.stats['failure']),
skipped=str(self.stats["skipped"]), skipped=str(self.stats['skipped']),
tests=str(numtests), tests=str(numtests),
time="%.3f" % suite_time_delta, time='%.3f' % suite_time_delta,
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(), timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(), hostname=platform.node(),
) )
@ -678,23 +680,23 @@ class LogXML:
suite_node.append(global_properties) suite_node.append(global_properties)
for node_reporter in self.node_reporters_ordered: for node_reporter in self.node_reporters_ordered:
suite_node.append(node_reporter.to_xml()) suite_node.append(node_reporter.to_xml())
testsuites = ET.Element("testsuites") testsuites = ET.Element('testsuites')
testsuites.append(suite_node) testsuites.append(suite_node)
logfile.write(ET.tostring(testsuites, encoding="unicode")) logfile.write(ET.tostring(testsuites, encoding='unicode'))
def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None: def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
terminalreporter.write_sep("-", f"generated xml file: {self.logfile}") terminalreporter.write_sep('-', f'generated xml file: {self.logfile}')
def add_global_property(self, name: str, value: object) -> None: def add_global_property(self, name: str, value: object) -> None:
__tracebackhide__ = True __tracebackhide__ = True
_check_record_param_type("name", name) _check_record_param_type('name', name)
self.global_properties.append((name, bin_xml_escape(value))) self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self) -> Optional[ET.Element]: def _get_global_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any.""" """Return a Junit node containing custom properties, if any."""
if self.global_properties: if self.global_properties:
properties = ET.Element("properties") properties = ET.Element('properties')
for name, value in self.global_properties: for name, value in self.global_properties:
properties.append(ET.Element("property", name=name, value=value)) properties.append(ET.Element('property', name=name, value=value))
return properties return properties
return None return None

View file

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Add backward compatibility support for the legacy py path type.""" """Add backward compatibility support for the legacy py path type."""
from __future__ import annotations
import dataclasses import dataclasses
from pathlib import Path
import shlex import shlex
import subprocess import subprocess
from pathlib import Path
from typing import Final from typing import Final
from typing import final from typing import final
from typing import List from typing import List
@ -12,8 +13,6 @@ from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
from iniconfig import SectionWrapper
from _pytest.cacheprovider import Cache from _pytest.cacheprovider import Cache
from _pytest.compat import LEGACY_PATH from _pytest.compat import LEGACY_PATH
from _pytest.compat import legacy_path from _pytest.compat import legacy_path
@ -33,6 +32,7 @@ from _pytest.pytester import Pytester
from _pytest.pytester import RunResult from _pytest.pytester import RunResult
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
from _pytest.tmpdir import TempPathFactory from _pytest.tmpdir import TempPathFactory
from iniconfig import SectionWrapper
if TYPE_CHECKING: if TYPE_CHECKING:
@ -50,8 +50,8 @@ class Testdir:
__test__ = False __test__ = False
CLOSE_STDIN: "Final" = Pytester.CLOSE_STDIN CLOSE_STDIN: Final = Pytester.CLOSE_STDIN
TimeoutExpired: "Final" = Pytester.TimeoutExpired TimeoutExpired: Final = Pytester.TimeoutExpired
def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None: def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
@ -95,14 +95,14 @@ class Testdir:
def makefile(self, ext, *args, **kwargs) -> LEGACY_PATH: def makefile(self, ext, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.makefile`.""" """See :meth:`Pytester.makefile`."""
if ext and not ext.startswith("."): if ext and not ext.startswith('.'):
# pytester.makefile is going to throw a ValueError in a way that # pytester.makefile is going to throw a ValueError in a way that
# testdir.makefile did not, because # testdir.makefile did not, because
# pathlib.Path is stricter suffixes than py.path # pathlib.Path is stricter suffixes than py.path
# This ext arguments is likely user error, but since testdir has # This ext arguments is likely user error, but since testdir has
# allowed this, we will prepend "." as a workaround to avoid breaking # allowed this, we will prepend "." as a workaround to avoid breaking
# testdir usage that worked before # testdir usage that worked before
ext = "." + ext ext = '.' + ext
return legacy_path(self._pytester.makefile(ext, *args, **kwargs)) return legacy_path(self._pytester.makefile(ext, *args, **kwargs))
def makeconftest(self, source) -> LEGACY_PATH: def makeconftest(self, source) -> LEGACY_PATH:
@ -145,7 +145,7 @@ class Testdir:
"""See :meth:`Pytester.copy_example`.""" """See :meth:`Pytester.copy_example`."""
return legacy_path(self._pytester.copy_example(name)) return legacy_path(self._pytester.copy_example(name))
def getnode(self, config: Config, arg) -> Optional[Union[Item, Collector]]: def getnode(self, config: Config, arg) -> Item | Collector | None:
"""See :meth:`Pytester.getnode`.""" """See :meth:`Pytester.getnode`."""
return self._pytester.getnode(config, arg) return self._pytester.getnode(config, arg)
@ -153,7 +153,7 @@ class Testdir:
"""See :meth:`Pytester.getpathnode`.""" """See :meth:`Pytester.getpathnode`."""
return self._pytester.getpathnode(path) return self._pytester.getpathnode(path)
def genitems(self, colitems: List[Union[Item, Collector]]) -> List[Item]: def genitems(self, colitems: list[Item | Collector]) -> list[Item]:
"""See :meth:`Pytester.genitems`.""" """See :meth:`Pytester.genitems`."""
return self._pytester.genitems(colitems) return self._pytester.genitems(colitems)
@ -172,7 +172,7 @@ class Testdir:
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False): def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
"""See :meth:`Pytester.inline_run`.""" """See :meth:`Pytester.inline_run`."""
return self._pytester.inline_run( return self._pytester.inline_run(
*args, plugins=plugins, no_reraise_ctrlc=no_reraise_ctrlc *args, plugins=plugins, no_reraise_ctrlc=no_reraise_ctrlc,
) )
def runpytest_inprocess(self, *args, **kwargs) -> RunResult: def runpytest_inprocess(self, *args, **kwargs) -> RunResult:
@ -191,7 +191,7 @@ class Testdir:
"""See :meth:`Pytester.parseconfigure`.""" """See :meth:`Pytester.parseconfigure`."""
return self._pytester.parseconfigure(*args) return self._pytester.parseconfigure(*args)
def getitem(self, source, funcname="test_func"): def getitem(self, source, funcname='test_func'):
"""See :meth:`Pytester.getitem`.""" """See :meth:`Pytester.getitem`."""
return self._pytester.getitem(source, funcname) return self._pytester.getitem(source, funcname)
@ -202,12 +202,12 @@ class Testdir:
def getmodulecol(self, source, configargs=(), withinit=False): def getmodulecol(self, source, configargs=(), withinit=False):
"""See :meth:`Pytester.getmodulecol`.""" """See :meth:`Pytester.getmodulecol`."""
return self._pytester.getmodulecol( return self._pytester.getmodulecol(
source, configargs=configargs, withinit=withinit source, configargs=configargs, withinit=withinit,
) )
def collect_by_name( def collect_by_name(
self, modcol: Collector, name: str self, modcol: Collector, name: str,
) -> Optional[Union[Item, Collector]]: ) -> Item | Collector | None:
"""See :meth:`Pytester.collect_by_name`.""" """See :meth:`Pytester.collect_by_name`."""
return self._pytester.collect_by_name(modcol, name) return self._pytester.collect_by_name(modcol, name)
@ -239,17 +239,17 @@ class Testdir:
return self._pytester.runpytest_subprocess(*args, timeout=timeout) return self._pytester.runpytest_subprocess(*args, timeout=timeout)
def spawn_pytest( def spawn_pytest(
self, string: str, expect_timeout: float = 10.0 self, string: str, expect_timeout: float = 10.0,
) -> "pexpect.spawn": ) -> pexpect.spawn:
"""See :meth:`Pytester.spawn_pytest`.""" """See :meth:`Pytester.spawn_pytest`."""
return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout) return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout)
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> "pexpect.spawn": def spawn(self, cmd: str, expect_timeout: float = 10.0) -> pexpect.spawn:
"""See :meth:`Pytester.spawn`.""" """See :meth:`Pytester.spawn`."""
return self._pytester.spawn(cmd, expect_timeout=expect_timeout) return self._pytester.spawn(cmd, expect_timeout=expect_timeout)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Testdir {self.tmpdir!r}>" return f'<Testdir {self.tmpdir!r}>'
def __str__(self) -> str: def __str__(self) -> str:
return str(self.tmpdir) return str(self.tmpdir)
@ -284,7 +284,7 @@ class TempdirFactory:
_tmppath_factory: TempPathFactory _tmppath_factory: TempPathFactory
def __init__( def __init__(
self, tmppath_factory: TempPathFactory, *, _ispytest: bool = False self, tmppath_factory: TempPathFactory, *, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._tmppath_factory = tmppath_factory self._tmppath_factory = tmppath_factory
@ -300,7 +300,7 @@ class TempdirFactory:
class LegacyTmpdirPlugin: class LegacyTmpdirPlugin:
@staticmethod @staticmethod
@fixture(scope="session") @fixture(scope='session')
def tmpdir_factory(request: FixtureRequest) -> TempdirFactory: def tmpdir_factory(request: FixtureRequest) -> TempdirFactory:
"""Return a :class:`pytest.TempdirFactory` instance for the test session.""" """Return a :class:`pytest.TempdirFactory` instance for the test session."""
# Set dynamically by pytest_configure(). # Set dynamically by pytest_configure().
@ -374,7 +374,7 @@ def Config_rootdir(self: Config) -> LEGACY_PATH:
return legacy_path(str(self.rootpath)) return legacy_path(str(self.rootpath))
def Config_inifile(self: Config) -> Optional[LEGACY_PATH]: def Config_inifile(self: Config) -> LEGACY_PATH | None:
"""The path to the :ref:`configfile <configfiles>`. """The path to the :ref:`configfile <configfiles>`.
Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`. Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`.
@ -395,16 +395,16 @@ def Session_stardir(self: Session) -> LEGACY_PATH:
def Config__getini_unknown_type( def Config__getini_unknown_type(
self, name: str, type: str, value: Union[str, List[str]] self, name: str, type: str, value: str | list[str],
): ):
if type == "pathlist": if type == 'pathlist':
# TODO: This assert is probably not valid in all cases. # TODO: This assert is probably not valid in all cases.
assert self.inipath is not None assert self.inipath is not None
dp = self.inipath.parent dp = self.inipath.parent
input_values = shlex.split(value) if isinstance(value, str) else value input_values = shlex.split(value) if isinstance(value, str) else value
return [legacy_path(str(dp / x)) for x in input_values] return [legacy_path(str(dp / x)) for x in input_values]
else: else:
raise ValueError(f"unknown configuration type: {type}", value) raise ValueError(f'unknown configuration type: {type}', value)
def Node_fspath(self: Node) -> LEGACY_PATH: def Node_fspath(self: Node) -> LEGACY_PATH:
@ -423,35 +423,35 @@ def pytest_load_initial_conftests(early_config: Config) -> None:
early_config.add_cleanup(mp.undo) early_config.add_cleanup(mp.undo)
# Add Cache.makedir(). # Add Cache.makedir().
mp.setattr(Cache, "makedir", Cache_makedir, raising=False) mp.setattr(Cache, 'makedir', Cache_makedir, raising=False)
# Add FixtureRequest.fspath property. # Add FixtureRequest.fspath property.
mp.setattr(FixtureRequest, "fspath", property(FixtureRequest_fspath), raising=False) mp.setattr(FixtureRequest, 'fspath', property(FixtureRequest_fspath), raising=False)
# Add TerminalReporter.startdir property. # Add TerminalReporter.startdir property.
mp.setattr( mp.setattr(
TerminalReporter, "startdir", property(TerminalReporter_startdir), raising=False TerminalReporter, 'startdir', property(TerminalReporter_startdir), raising=False,
) )
# Add Config.{invocation_dir,rootdir,inifile} properties. # Add Config.{invocation_dir,rootdir,inifile} properties.
mp.setattr(Config, "invocation_dir", property(Config_invocation_dir), raising=False) mp.setattr(Config, 'invocation_dir', property(Config_invocation_dir), raising=False)
mp.setattr(Config, "rootdir", property(Config_rootdir), raising=False) mp.setattr(Config, 'rootdir', property(Config_rootdir), raising=False)
mp.setattr(Config, "inifile", property(Config_inifile), raising=False) mp.setattr(Config, 'inifile', property(Config_inifile), raising=False)
# Add Session.startdir property. # Add Session.startdir property.
mp.setattr(Session, "startdir", property(Session_stardir), raising=False) mp.setattr(Session, 'startdir', property(Session_stardir), raising=False)
# Add pathlist configuration type. # Add pathlist configuration type.
mp.setattr(Config, "_getini_unknown_type", Config__getini_unknown_type) mp.setattr(Config, '_getini_unknown_type', Config__getini_unknown_type)
# Add Node.fspath property. # Add Node.fspath property.
mp.setattr(Node, "fspath", property(Node_fspath, Node_fspath_set), raising=False) mp.setattr(Node, 'fspath', property(Node_fspath, Node_fspath_set), raising=False)
@hookimpl @hookimpl
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
"""Installs the LegacyTmpdirPlugin if the ``tmpdir`` plugin is also installed.""" """Installs the LegacyTmpdirPlugin if the ``tmpdir`` plugin is also installed."""
if config.pluginmanager.has_plugin("tmpdir"): if config.pluginmanager.has_plugin('tmpdir'):
mp = MonkeyPatch() mp = MonkeyPatch()
config.add_cleanup(mp.undo) config.add_cleanup(mp.undo)
# Create TmpdirFactory and attach it to the config object. # Create TmpdirFactory and attach it to the config object.
@ -466,15 +466,15 @@ def pytest_configure(config: Config) -> None:
pass pass
else: else:
_tmpdirhandler = TempdirFactory(tmp_path_factory, _ispytest=True) _tmpdirhandler = TempdirFactory(tmp_path_factory, _ispytest=True)
mp.setattr(config, "_tmpdirhandler", _tmpdirhandler, raising=False) mp.setattr(config, '_tmpdirhandler', _tmpdirhandler, raising=False)
config.pluginmanager.register(LegacyTmpdirPlugin, "legacypath-tmpdir") config.pluginmanager.register(LegacyTmpdirPlugin, 'legacypath-tmpdir')
@hookimpl @hookimpl
def pytest_plugin_registered(plugin: object, manager: PytestPluginManager) -> None: def pytest_plugin_registered(plugin: object, manager: PytestPluginManager) -> None:
# pytester is not loaded by default and is commonly loaded from a conftest, # pytester is not loaded by default and is commonly loaded from a conftest,
# so checking for it in `pytest_configure` is not enough. # so checking for it in `pytest_configure` is not enough.
is_pytester = plugin is manager.get_plugin("pytester") is_pytester = plugin is manager.get_plugin('pytester')
if is_pytester and not manager.is_registered(LegacyTestdirPlugin): if is_pytester and not manager.is_registered(LegacyTestdirPlugin):
manager.register(LegacyTestdirPlugin, "legacypath-pytester") manager.register(LegacyTestdirPlugin, 'legacypath-pytester')

View file

@ -1,17 +1,19 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Access and control log capturing.""" """Access and control log capturing."""
from __future__ import annotations
import io
import logging
import os
import re
from contextlib import contextmanager from contextlib import contextmanager
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta
from datetime import timezone from datetime import timezone
import io
from io import StringIO from io import StringIO
import logging
from logging import LogRecord from logging import LogRecord
import os
from pathlib import Path from pathlib import Path
import re
from types import TracebackType from types import TracebackType
from typing import AbstractSet from typing import AbstractSet
from typing import Dict from typing import Dict
@ -50,15 +52,15 @@ if TYPE_CHECKING:
else: else:
logging_StreamHandler = logging.StreamHandler logging_StreamHandler = logging.StreamHandler
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s" DEFAULT_LOG_FORMAT = '%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s'
DEFAULT_LOG_DATE_FORMAT = "%H:%M:%S" DEFAULT_LOG_DATE_FORMAT = '%H:%M:%S'
_ANSI_ESCAPE_SEQ = re.compile(r"\x1b\[[\d;]+m") _ANSI_ESCAPE_SEQ = re.compile(r'\x1b\[[\d;]+m')
caplog_handler_key = StashKey["LogCaptureHandler"]() caplog_handler_key = StashKey['LogCaptureHandler']()
caplog_records_key = StashKey[Dict[str, List[logging.LogRecord]]]() caplog_records_key = StashKey[Dict[str, List[logging.LogRecord]]]()
def _remove_ansi_escape_sequences(text: str) -> str: def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text) return _ANSI_ESCAPE_SEQ.sub('', text)
class DatetimeFormatter(logging.Formatter): class DatetimeFormatter(logging.Formatter):
@ -67,8 +69,8 @@ class DatetimeFormatter(logging.Formatter):
:func:`time.strftime` in case of microseconds in format string. :func:`time.strftime` in case of microseconds in format string.
""" """
def formatTime(self, record: LogRecord, datefmt: Optional[str] = None) -> str: def formatTime(self, record: LogRecord, datefmt: str | None = None) -> str:
if datefmt and "%f" in datefmt: if datefmt and '%f' in datefmt:
ct = self.converter(record.created) ct = self.converter(record.created)
tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone) tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone)
# Construct `datetime.datetime` object from `struct_time` # Construct `datetime.datetime` object from `struct_time`
@ -85,21 +87,21 @@ class ColoredLevelFormatter(DatetimeFormatter):
log format passed to __init__.""" log format passed to __init__."""
LOGLEVEL_COLOROPTS: Mapping[int, AbstractSet[str]] = { LOGLEVEL_COLOROPTS: Mapping[int, AbstractSet[str]] = {
logging.CRITICAL: {"red"}, logging.CRITICAL: {'red'},
logging.ERROR: {"red", "bold"}, logging.ERROR: {'red', 'bold'},
logging.WARNING: {"yellow"}, logging.WARNING: {'yellow'},
logging.WARN: {"yellow"}, logging.WARN: {'yellow'},
logging.INFO: {"green"}, logging.INFO: {'green'},
logging.DEBUG: {"purple"}, logging.DEBUG: {'purple'},
logging.NOTSET: set(), logging.NOTSET: set(),
} }
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*(?:\.\d+)?s)") LEVELNAME_FMT_REGEX = re.compile(r'%\(levelname\)([+-.]?\d*(?:\.\d+)?s)')
def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None: def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._terminalwriter = terminalwriter self._terminalwriter = terminalwriter
self._original_fmt = self._style._fmt self._original_fmt = self._style._fmt
self._level_to_fmt_mapping: Dict[int, str] = {} self._level_to_fmt_mapping: dict[int, str] = {}
for level, color_opts in self.LOGLEVEL_COLOROPTS.items(): for level, color_opts in self.LOGLEVEL_COLOROPTS.items():
self.add_color_level(level, *color_opts) self.add_color_level(level, *color_opts)
@ -123,15 +125,15 @@ class ColoredLevelFormatter(DatetimeFormatter):
return return
levelname_fmt = levelname_fmt_match.group() levelname_fmt = levelname_fmt_match.group()
formatted_levelname = levelname_fmt % {"levelname": logging.getLevelName(level)} formatted_levelname = levelname_fmt % {'levelname': logging.getLevelName(level)}
# add ANSI escape sequences around the formatted levelname # add ANSI escape sequences around the formatted levelname
color_kwargs = {name: True for name in color_opts} color_kwargs = {name: True for name in color_opts}
colorized_formatted_levelname = self._terminalwriter.markup( colorized_formatted_levelname = self._terminalwriter.markup(
formatted_levelname, **color_kwargs formatted_levelname, **color_kwargs,
) )
self._level_to_fmt_mapping[level] = self.LEVELNAME_FMT_REGEX.sub( self._level_to_fmt_mapping[level] = self.LEVELNAME_FMT_REGEX.sub(
colorized_formatted_levelname, self._fmt colorized_formatted_levelname, self._fmt,
) )
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
@ -147,12 +149,12 @@ class PercentStyleMultiline(logging.PercentStyle):
formats the message as if each line were logged separately. formats the message as if each line were logged separately.
""" """
def __init__(self, fmt: str, auto_indent: Union[int, str, bool, None]) -> None: def __init__(self, fmt: str, auto_indent: int | str | bool | None) -> None:
super().__init__(fmt) super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent) self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod @staticmethod
def _get_auto_indent(auto_indent_option: Union[int, str, bool, None]) -> int: def _get_auto_indent(auto_indent_option: int | str | bool | None) -> int:
"""Determine the current auto indentation setting. """Determine the current auto indentation setting.
Specify auto indent behavior (on/off/fixed) by passing in Specify auto indent behavior (on/off/fixed) by passing in
@ -206,8 +208,8 @@ class PercentStyleMultiline(logging.PercentStyle):
return 0 return 0
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message: if '\n' in record.message:
if hasattr(record, "auto_indent"): if hasattr(record, 'auto_indent'):
# Passed in from the "extra={}" kwarg on the call to logging.log(). # Passed in from the "extra={}" kwarg on the call to logging.log().
auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined] auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined]
else: else:
@ -215,17 +217,17 @@ class PercentStyleMultiline(logging.PercentStyle):
if auto_indent: if auto_indent:
lines = record.message.splitlines() lines = record.message.splitlines()
formatted = self._fmt % {**record.__dict__, "message": lines[0]} formatted = self._fmt % {**record.__dict__, 'message': lines[0]}
if auto_indent < 0: if auto_indent < 0:
indentation = _remove_ansi_escape_sequences(formatted).find( indentation = _remove_ansi_escape_sequences(formatted).find(
lines[0] lines[0],
) )
else: else:
# Optimizes logging by allowing a fixed indentation. # Optimizes logging by allowing a fixed indentation.
indentation = auto_indent indentation = auto_indent
lines[0] = formatted lines[0] = formatted
return ("\n" + " " * indentation).join(lines) return ('\n' + ' ' * indentation).join(lines)
return self._fmt % record.__dict__ return self._fmt % record.__dict__
@ -240,114 +242,114 @@ def get_option_ini(config: Config, *names: str):
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
"""Add options to control log capturing.""" """Add options to control log capturing."""
group = parser.getgroup("logging") group = parser.getgroup('logging')
def add_option_ini(option, dest, default=None, type=None, **kwargs): def add_option_ini(option, dest, default=None, type=None, **kwargs):
parser.addini( parser.addini(
dest, default=default, type=type, help="Default value for " + option dest, default=default, type=type, help='Default value for ' + option,
) )
group.addoption(option, dest=dest, **kwargs) group.addoption(option, dest=dest, **kwargs)
add_option_ini( add_option_ini(
"--log-level", '--log-level',
dest="log_level", dest='log_level',
default=None, default=None,
metavar="LEVEL", metavar='LEVEL',
help=( help=(
"Level of messages to catch/display." 'Level of messages to catch/display.'
" Not set by default, so it depends on the root/parent log handler's" " Not set by default, so it depends on the root/parent log handler's"
' effective level, where it is "WARNING" by default.' ' effective level, where it is "WARNING" by default.'
), ),
) )
add_option_ini( add_option_ini(
"--log-format", '--log-format',
dest="log_format", dest='log_format',
default=DEFAULT_LOG_FORMAT, default=DEFAULT_LOG_FORMAT,
help="Log format used by the logging module", help='Log format used by the logging module',
) )
add_option_ini( add_option_ini(
"--log-date-format", '--log-date-format',
dest="log_date_format", dest='log_date_format',
default=DEFAULT_LOG_DATE_FORMAT, default=DEFAULT_LOG_DATE_FORMAT,
help="Log date format used by the logging module", help='Log date format used by the logging module',
) )
parser.addini( parser.addini(
"log_cli", 'log_cli',
default=False, default=False,
type="bool", type='bool',
help='Enable log display during test run (also known as "live logging")', help='Enable log display during test run (also known as "live logging")',
) )
add_option_ini( add_option_ini(
"--log-cli-level", dest="log_cli_level", default=None, help="CLI logging level" '--log-cli-level', dest='log_cli_level', default=None, help='CLI logging level',
) )
add_option_ini( add_option_ini(
"--log-cli-format", '--log-cli-format',
dest="log_cli_format", dest='log_cli_format',
default=None, default=None,
help="Log format used by the logging module", help='Log format used by the logging module',
) )
add_option_ini( add_option_ini(
"--log-cli-date-format", '--log-cli-date-format',
dest="log_cli_date_format", dest='log_cli_date_format',
default=None, default=None,
help="Log date format used by the logging module", help='Log date format used by the logging module',
) )
add_option_ini( add_option_ini(
"--log-file", '--log-file',
dest="log_file", dest='log_file',
default=None, default=None,
help="Path to a file when logging will be written to", help='Path to a file when logging will be written to',
) )
add_option_ini( add_option_ini(
"--log-file-mode", '--log-file-mode',
dest="log_file_mode", dest='log_file_mode',
default="w", default='w',
choices=["w", "a"], choices=['w', 'a'],
help="Log file open mode", help='Log file open mode',
) )
add_option_ini( add_option_ini(
"--log-file-level", '--log-file-level',
dest="log_file_level", dest='log_file_level',
default=None, default=None,
help="Log file logging level", help='Log file logging level',
) )
add_option_ini( add_option_ini(
"--log-file-format", '--log-file-format',
dest="log_file_format", dest='log_file_format',
default=None, default=None,
help="Log format used by the logging module", help='Log format used by the logging module',
) )
add_option_ini( add_option_ini(
"--log-file-date-format", '--log-file-date-format',
dest="log_file_date_format", dest='log_file_date_format',
default=None, default=None,
help="Log date format used by the logging module", help='Log date format used by the logging module',
) )
add_option_ini( add_option_ini(
"--log-auto-indent", '--log-auto-indent',
dest="log_auto_indent", dest='log_auto_indent',
default=None, default=None,
help="Auto-indent multiline messages passed to the logging module. Accepts true|on, false|off or an integer.", help='Auto-indent multiline messages passed to the logging module. Accepts true|on, false|off or an integer.',
) )
group.addoption( group.addoption(
"--log-disable", '--log-disable',
action="append", action='append',
default=[], default=[],
dest="logger_disable", dest='logger_disable',
help="Disable a logger by name. Can be passed multiple times.", help='Disable a logger by name. Can be passed multiple times.',
) )
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler) _HandlerType = TypeVar('_HandlerType', bound=logging.Handler)
# Not using @contextmanager for performance reasons. # Not using @contextmanager for performance reasons.
class catching_logs(Generic[_HandlerType]): class catching_logs(Generic[_HandlerType]):
"""Context manager that prepares the whole logging machinery properly.""" """Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level") __slots__ = ('handler', 'level', 'orig_level')
def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None: def __init__(self, handler: _HandlerType, level: int | None = None) -> None:
self.handler = handler self.handler = handler
self.level = level self.level = level
@ -363,9 +365,9 @@ class catching_logs(Generic[_HandlerType]):
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
root_logger = logging.getLogger() root_logger = logging.getLogger()
if self.level is not None: if self.level is not None:
@ -379,7 +381,7 @@ class LogCaptureHandler(logging_StreamHandler):
def __init__(self) -> None: def __init__(self) -> None:
"""Create a new log handler.""" """Create a new log handler."""
super().__init__(StringIO()) super().__init__(StringIO())
self.records: List[logging.LogRecord] = [] self.records: list[logging.LogRecord] = []
def emit(self, record: logging.LogRecord) -> None: def emit(self, record: logging.LogRecord) -> None:
"""Keep the log records in a list in addition to the log text.""" """Keep the log records in a list in addition to the log text."""
@ -410,10 +412,10 @@ class LogCaptureFixture:
def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None: def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._item = item self._item = item
self._initial_handler_level: Optional[int] = None self._initial_handler_level: int | None = None
# Dict of log name -> log level. # Dict of log name -> log level.
self._initial_logger_levels: Dict[Optional[str], int] = {} self._initial_logger_levels: dict[str | None, int] = {}
self._initial_disabled_logging_level: Optional[int] = None self._initial_disabled_logging_level: int | None = None
def _finalize(self) -> None: def _finalize(self) -> None:
"""Finalize the fixture. """Finalize the fixture.
@ -437,8 +439,8 @@ class LogCaptureFixture:
return self._item.stash[caplog_handler_key] return self._item.stash[caplog_handler_key]
def get_records( def get_records(
self, when: Literal["setup", "call", "teardown"] self, when: Literal['setup', 'call', 'teardown'],
) -> List[logging.LogRecord]: ) -> list[logging.LogRecord]:
"""Get the logging records for one of the possible test phases. """Get the logging records for one of the possible test phases.
:param when: :param when:
@ -457,12 +459,12 @@ class LogCaptureFixture:
return _remove_ansi_escape_sequences(self.handler.stream.getvalue()) return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property @property
def records(self) -> List[logging.LogRecord]: def records(self) -> list[logging.LogRecord]:
"""The list of log records.""" """The list of log records."""
return self.handler.records return self.handler.records
@property @property
def record_tuples(self) -> List[Tuple[str, int, str]]: def record_tuples(self) -> list[tuple[str, int, str]]:
"""A list of a stripped down version of log records intended """A list of a stripped down version of log records intended
for use in assertion comparison. for use in assertion comparison.
@ -473,7 +475,7 @@ class LogCaptureFixture:
return [(r.name, r.levelno, r.getMessage()) for r in self.records] return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property @property
def messages(self) -> List[str]: def messages(self) -> list[str]:
"""A list of format-interpolated log messages. """A list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for Unlike 'records', which contains the format string and parameters for
@ -496,7 +498,7 @@ class LogCaptureFixture:
self.handler.clear() self.handler.clear()
def _force_enable_logging( def _force_enable_logging(
self, level: Union[int, str], logger_obj: logging.Logger self, level: int | str, logger_obj: logging.Logger,
) -> int: ) -> int:
"""Enable the desired logging level if the global level was disabled via ``logging.disabled``. """Enable the desired logging level if the global level was disabled via ``logging.disabled``.
@ -529,7 +531,7 @@ class LogCaptureFixture:
return original_disable_level return original_disable_level
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None: def set_level(self, level: int | str, logger: str | None = None) -> None:
"""Set the threshold level of a logger for the duration of a test. """Set the threshold level of a logger for the duration of a test.
Logging messages which are less severe than this level will not be captured. Logging messages which are less severe than this level will not be captured.
@ -556,7 +558,7 @@ class LogCaptureFixture:
@contextmanager @contextmanager
def at_level( def at_level(
self, level: Union[int, str], logger: Optional[str] = None self, level: int | str, logger: str | None = None,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
"""Context manager that sets the level for capturing of logs. After """Context manager that sets the level for capturing of logs. After
the end of the 'with' statement the level is restored to its original the end of the 'with' statement the level is restored to its original
@ -614,7 +616,7 @@ def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
result._finalize() result._finalize()
def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[int]: def get_log_level_for_setting(config: Config, *setting_names: str) -> int | None:
for setting_name in setting_names: for setting_name in setting_names:
log_level = config.getoption(setting_name) log_level = config.getoption(setting_name)
if log_level is None: if log_level is None:
@ -633,14 +635,14 @@ def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[i
raise UsageError( raise UsageError(
f"'{log_level}' is not recognized as a logging level name for " f"'{log_level}' is not recognized as a logging level name for "
f"'{setting_name}'. Please consider passing the " f"'{setting_name}'. Please consider passing the "
"logging level num instead." 'logging level num instead.',
) from e ) from e
# run after terminalreporter/capturemanager are configured # run after terminalreporter/capturemanager are configured
@hookimpl(trylast=True) @hookimpl(trylast=True)
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
config.pluginmanager.register(LoggingPlugin(config), "logging-plugin") config.pluginmanager.register(LoggingPlugin(config), 'logging-plugin')
class LoggingPlugin: class LoggingPlugin:
@ -656,11 +658,11 @@ class LoggingPlugin:
# Report logging. # Report logging.
self.formatter = self._create_formatter( self.formatter = self._create_formatter(
get_option_ini(config, "log_format"), get_option_ini(config, 'log_format'),
get_option_ini(config, "log_date_format"), get_option_ini(config, 'log_date_format'),
get_option_ini(config, "log_auto_indent"), get_option_ini(config, 'log_auto_indent'),
) )
self.log_level = get_log_level_for_setting(config, "log_level") self.log_level = get_log_level_for_setting(config, 'log_level')
self.caplog_handler = LogCaptureHandler() self.caplog_handler = LogCaptureHandler()
self.caplog_handler.setFormatter(self.formatter) self.caplog_handler.setFormatter(self.formatter)
self.report_handler = LogCaptureHandler() self.report_handler = LogCaptureHandler()
@ -668,52 +670,52 @@ class LoggingPlugin:
# File logging. # File logging.
self.log_file_level = get_log_level_for_setting( self.log_file_level = get_log_level_for_setting(
config, "log_file_level", "log_level" config, 'log_file_level', 'log_level',
) )
log_file = get_option_ini(config, "log_file") or os.devnull log_file = get_option_ini(config, 'log_file') or os.devnull
if log_file != os.devnull: if log_file != os.devnull:
directory = os.path.dirname(os.path.abspath(log_file)) directory = os.path.dirname(os.path.abspath(log_file))
if not os.path.isdir(directory): if not os.path.isdir(directory):
os.makedirs(directory) os.makedirs(directory)
self.log_file_mode = get_option_ini(config, "log_file_mode") or "w" self.log_file_mode = get_option_ini(config, 'log_file_mode') or 'w'
self.log_file_handler = _FileHandler( self.log_file_handler = _FileHandler(
log_file, mode=self.log_file_mode, encoding="UTF-8" log_file, mode=self.log_file_mode, encoding='UTF-8',
) )
log_file_format = get_option_ini(config, "log_file_format", "log_format") log_file_format = get_option_ini(config, 'log_file_format', 'log_format')
log_file_date_format = get_option_ini( log_file_date_format = get_option_ini(
config, "log_file_date_format", "log_date_format" config, 'log_file_date_format', 'log_date_format',
) )
log_file_formatter = DatetimeFormatter( log_file_formatter = DatetimeFormatter(
log_file_format, datefmt=log_file_date_format log_file_format, datefmt=log_file_date_format,
) )
self.log_file_handler.setFormatter(log_file_formatter) self.log_file_handler.setFormatter(log_file_formatter)
# CLI/live logging. # CLI/live logging.
self.log_cli_level = get_log_level_for_setting( self.log_cli_level = get_log_level_for_setting(
config, "log_cli_level", "log_level" config, 'log_cli_level', 'log_level',
) )
if self._log_cli_enabled(): if self._log_cli_enabled():
terminal_reporter = config.pluginmanager.get_plugin("terminalreporter") terminal_reporter = config.pluginmanager.get_plugin('terminalreporter')
# Guaranteed by `_log_cli_enabled()`. # Guaranteed by `_log_cli_enabled()`.
assert terminal_reporter is not None assert terminal_reporter is not None
capture_manager = config.pluginmanager.get_plugin("capturemanager") capture_manager = config.pluginmanager.get_plugin('capturemanager')
# if capturemanager plugin is disabled, live logging still works. # if capturemanager plugin is disabled, live logging still works.
self.log_cli_handler: Union[ self.log_cli_handler: (
_LiveLoggingStreamHandler, _LiveLoggingNullHandler _LiveLoggingStreamHandler | _LiveLoggingNullHandler
] = _LiveLoggingStreamHandler(terminal_reporter, capture_manager) ) = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
else: else:
self.log_cli_handler = _LiveLoggingNullHandler() self.log_cli_handler = _LiveLoggingNullHandler()
log_cli_formatter = self._create_formatter( log_cli_formatter = self._create_formatter(
get_option_ini(config, "log_cli_format", "log_format"), get_option_ini(config, 'log_cli_format', 'log_format'),
get_option_ini(config, "log_cli_date_format", "log_date_format"), get_option_ini(config, 'log_cli_date_format', 'log_date_format'),
get_option_ini(config, "log_auto_indent"), get_option_ini(config, 'log_auto_indent'),
) )
self.log_cli_handler.setFormatter(log_cli_formatter) self.log_cli_handler.setFormatter(log_cli_formatter)
self._disable_loggers(loggers_to_disable=config.option.logger_disable) self._disable_loggers(loggers_to_disable=config.option.logger_disable)
def _disable_loggers(self, loggers_to_disable: List[str]) -> None: def _disable_loggers(self, loggers_to_disable: list[str]) -> None:
if not loggers_to_disable: if not loggers_to_disable:
return return
@ -723,18 +725,18 @@ class LoggingPlugin:
def _create_formatter(self, log_format, log_date_format, auto_indent): def _create_formatter(self, log_format, log_date_format, auto_indent):
# Color option doesn't exist if terminal plugin is disabled. # Color option doesn't exist if terminal plugin is disabled.
color = getattr(self._config.option, "color", "no") color = getattr(self._config.option, 'color', 'no')
if color != "no" and ColoredLevelFormatter.LEVELNAME_FMT_REGEX.search( if color != 'no' and ColoredLevelFormatter.LEVELNAME_FMT_REGEX.search(
log_format log_format,
): ):
formatter: logging.Formatter = ColoredLevelFormatter( formatter: logging.Formatter = ColoredLevelFormatter(
create_terminal_writer(self._config), log_format, log_date_format create_terminal_writer(self._config), log_format, log_date_format,
) )
else: else:
formatter = DatetimeFormatter(log_format, log_date_format) formatter = DatetimeFormatter(log_format, log_date_format)
formatter._style = PercentStyleMultiline( formatter._style = PercentStyleMultiline(
formatter._style._fmt, auto_indent=auto_indent formatter._style._fmt, auto_indent=auto_indent,
) )
return formatter return formatter
@ -756,7 +758,7 @@ class LoggingPlugin:
fpath.parent.mkdir(exist_ok=True, parents=True) fpath.parent.mkdir(exist_ok=True, parents=True)
# https://github.com/python/mypy/issues/11193 # https://github.com/python/mypy/issues/11193
stream: io.TextIOWrapper = fpath.open(mode=self.log_file_mode, encoding="UTF-8") # type: ignore[assignment] stream: io.TextIOWrapper = fpath.open(mode=self.log_file_mode, encoding='UTF-8') # type: ignore[assignment]
old_stream = self.log_file_handler.setStream(stream) old_stream = self.log_file_handler.setStream(stream)
if old_stream: if old_stream:
old_stream.close() old_stream.close()
@ -764,12 +766,12 @@ class LoggingPlugin:
def _log_cli_enabled(self) -> bool: def _log_cli_enabled(self) -> bool:
"""Return whether live logging is enabled.""" """Return whether live logging is enabled."""
enabled = self._config.getoption( enabled = self._config.getoption(
"--log-cli-level" '--log-cli-level',
) is not None or self._config.getini("log_cli") ) is not None or self._config.getini('log_cli')
if not enabled: if not enabled:
return False return False
terminal_reporter = self._config.pluginmanager.get_plugin("terminalreporter") terminal_reporter = self._config.pluginmanager.get_plugin('terminalreporter')
if terminal_reporter is None: if terminal_reporter is None:
# terminal reporter is disabled e.g. by pytest-xdist. # terminal reporter is disabled e.g. by pytest-xdist.
return False return False
@ -778,7 +780,7 @@ class LoggingPlugin:
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_sessionstart(self) -> Generator[None, None, None]: def pytest_sessionstart(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionstart") self.log_cli_handler.set_when('sessionstart')
with catching_logs(self.log_cli_handler, level=self.log_cli_level): with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level): with catching_logs(self.log_file_handler, level=self.log_file_level):
@ -786,7 +788,7 @@ class LoggingPlugin:
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_collection(self) -> Generator[None, None, None]: def pytest_collection(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("collection") self.log_cli_handler.set_when('collection')
with catching_logs(self.log_cli_handler, level=self.log_cli_level): with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level): with catching_logs(self.log_file_handler, level=self.log_file_level):
@ -797,7 +799,7 @@ class LoggingPlugin:
if session.config.option.collectonly: if session.config.option.collectonly:
return (yield) return (yield)
if self._log_cli_enabled() and self._config.getoption("verbose") < 1: if self._log_cli_enabled() and self._config.getoption('verbose') < 1:
# The verbose flag is needed to avoid messy test progress output. # The verbose flag is needed to avoid messy test progress output.
self._config.option.verbose = 1 self._config.option.verbose = 1
@ -808,11 +810,11 @@ class LoggingPlugin:
@hookimpl @hookimpl
def pytest_runtest_logstart(self) -> None: def pytest_runtest_logstart(self) -> None:
self.log_cli_handler.reset() self.log_cli_handler.reset()
self.log_cli_handler.set_when("start") self.log_cli_handler.set_when('start')
@hookimpl @hookimpl
def pytest_runtest_logreport(self) -> None: def pytest_runtest_logreport(self) -> None:
self.log_cli_handler.set_when("logreport") self.log_cli_handler.set_when('logreport')
def _runtest_for(self, item: nodes.Item, when: str) -> Generator[None, None, None]: def _runtest_for(self, item: nodes.Item, when: str) -> Generator[None, None, None]:
"""Implement the internals of the pytest_runtest_xxx() hooks.""" """Implement the internals of the pytest_runtest_xxx() hooks."""
@ -832,39 +834,39 @@ class LoggingPlugin:
yield yield
finally: finally:
log = report_handler.stream.getvalue().strip() log = report_handler.stream.getvalue().strip()
item.add_report_section(when, "log", log) item.add_report_section(when, 'log', log)
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]: def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("setup") self.log_cli_handler.set_when('setup')
empty: Dict[str, List[logging.LogRecord]] = {} empty: dict[str, list[logging.LogRecord]] = {}
item.stash[caplog_records_key] = empty item.stash[caplog_records_key] = empty
yield from self._runtest_for(item, "setup") yield from self._runtest_for(item, 'setup')
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_call(self, item: nodes.Item) -> Generator[None, None, None]: def pytest_runtest_call(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("call") self.log_cli_handler.set_when('call')
yield from self._runtest_for(item, "call") yield from self._runtest_for(item, 'call')
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None, None, None]: def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("teardown") self.log_cli_handler.set_when('teardown')
try: try:
yield from self._runtest_for(item, "teardown") yield from self._runtest_for(item, 'teardown')
finally: finally:
del item.stash[caplog_records_key] del item.stash[caplog_records_key]
del item.stash[caplog_handler_key] del item.stash[caplog_handler_key]
@hookimpl @hookimpl
def pytest_runtest_logfinish(self) -> None: def pytest_runtest_logfinish(self) -> None:
self.log_cli_handler.set_when("finish") self.log_cli_handler.set_when('finish')
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_sessionfinish(self) -> Generator[None, None, None]: def pytest_sessionfinish(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionfinish") self.log_cli_handler.set_when('sessionfinish')
with catching_logs(self.log_cli_handler, level=self.log_cli_level): with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level): with catching_logs(self.log_file_handler, level=self.log_file_level):
@ -901,7 +903,7 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
def __init__( def __init__(
self, self,
terminal_reporter: TerminalReporter, terminal_reporter: TerminalReporter,
capture_manager: Optional[CaptureManager], capture_manager: CaptureManager | None,
) -> None: ) -> None:
super().__init__(stream=terminal_reporter) # type: ignore[arg-type] super().__init__(stream=terminal_reporter) # type: ignore[arg-type]
self.capture_manager = capture_manager self.capture_manager = capture_manager
@ -913,11 +915,11 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
"""Reset the handler; should be called before the start of each test.""" """Reset the handler; should be called before the start of each test."""
self._first_record_emitted = False self._first_record_emitted = False
def set_when(self, when: Optional[str]) -> None: def set_when(self, when: str | None) -> None:
"""Prepare for the given test phase (setup/call/teardown).""" """Prepare for the given test phase (setup/call/teardown)."""
self._when = when self._when = when
self._section_name_shown = False self._section_name_shown = False
if when == "start": if when == 'start':
self._test_outcome_written = False self._test_outcome_written = False
def emit(self, record: logging.LogRecord) -> None: def emit(self, record: logging.LogRecord) -> None:
@ -928,14 +930,14 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
) )
with ctx_manager: with ctx_manager:
if not self._first_record_emitted: if not self._first_record_emitted:
self.stream.write("\n") self.stream.write('\n')
self._first_record_emitted = True self._first_record_emitted = True
elif self._when in ("teardown", "finish"): elif self._when in ('teardown', 'finish'):
if not self._test_outcome_written: if not self._test_outcome_written:
self._test_outcome_written = True self._test_outcome_written = True
self.stream.write("\n") self.stream.write('\n')
if not self._section_name_shown and self._when: if not self._section_name_shown and self._when:
self.stream.section("live log " + self._when, sep="-", bold=True) self.stream.section('live log ' + self._when, sep='-', bold=True)
self._section_name_shown = True self._section_name_shown = True
super().emit(record) super().emit(record)

View file

@ -1,14 +1,15 @@
"""Core implementation of the testing process: init, session, runtest loop.""" """Core implementation of the testing process: init, session, runtest loop."""
from __future__ import annotations
import argparse import argparse
import dataclasses import dataclasses
import fnmatch import fnmatch
import functools import functools
import importlib
import importlib.util import importlib.util
import os import os
from pathlib import Path
import sys import sys
import warnings
from pathlib import Path
from typing import AbstractSet from typing import AbstractSet
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
@ -24,12 +25,10 @@ from typing import Sequence
from typing import Tuple from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
import warnings
import pluggy
from _pytest import nodes
import _pytest._code import _pytest._code
import pluggy
from _pytest import nodes
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import directory_arg from _pytest.config import directory_arg
from _pytest.config import ExitCode from _pytest.config import ExitCode
@ -58,195 +57,195 @@ if TYPE_CHECKING:
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
parser.addini( parser.addini(
"norecursedirs", 'norecursedirs',
"Directory patterns to avoid for recursion", 'Directory patterns to avoid for recursion',
type="args", type='args',
default=[ default=[
"*.egg", '*.egg',
".*", '.*',
"_darcs", '_darcs',
"build", 'build',
"CVS", 'CVS',
"dist", 'dist',
"node_modules", 'node_modules',
"venv", 'venv',
"{arch}", '{arch}',
], ],
) )
parser.addini( parser.addini(
"testpaths", 'testpaths',
"Directories to search for tests when no files or directories are given on the " 'Directories to search for tests when no files or directories are given on the '
"command line", 'command line',
type="args", type='args',
default=[], default=[],
) )
group = parser.getgroup("general", "Running and selection options") group = parser.getgroup('general', 'Running and selection options')
group._addoption( group._addoption(
"-x", '-x',
"--exitfirst", '--exitfirst',
action="store_const", action='store_const',
dest="maxfail", dest='maxfail',
const=1, const=1,
help="Exit instantly on first error or failed test", help='Exit instantly on first error or failed test',
) )
group = parser.getgroup("pytest-warnings") group = parser.getgroup('pytest-warnings')
group.addoption( group.addoption(
"-W", '-W',
"--pythonwarnings", '--pythonwarnings',
action="append", action='append',
help="Set which warnings to report, see -W option of Python itself", help='Set which warnings to report, see -W option of Python itself',
) )
parser.addini( parser.addini(
"filterwarnings", 'filterwarnings',
type="linelist", type='linelist',
help="Each line specifies a pattern for " help='Each line specifies a pattern for '
"warnings.filterwarnings. " 'warnings.filterwarnings. '
"Processed after -W/--pythonwarnings.", 'Processed after -W/--pythonwarnings.',
) )
group._addoption( group._addoption(
"--maxfail", '--maxfail',
metavar="num", metavar='num',
action="store", action='store',
type=int, type=int,
dest="maxfail", dest='maxfail',
default=0, default=0,
help="Exit after first num failures or errors", help='Exit after first num failures or errors',
) )
group._addoption( group._addoption(
"--strict-config", '--strict-config',
action="store_true", action='store_true',
help="Any warnings encountered while parsing the `pytest` section of the " help='Any warnings encountered while parsing the `pytest` section of the '
"configuration file raise errors", 'configuration file raise errors',
) )
group._addoption( group._addoption(
"--strict-markers", '--strict-markers',
action="store_true", action='store_true',
help="Markers not registered in the `markers` section of the configuration " help='Markers not registered in the `markers` section of the configuration '
"file raise errors", 'file raise errors',
) )
group._addoption( group._addoption(
"--strict", '--strict',
action="store_true", action='store_true',
help="(Deprecated) alias to --strict-markers", help='(Deprecated) alias to --strict-markers',
) )
group._addoption( group._addoption(
"-c", '-c',
"--config-file", '--config-file',
metavar="FILE", metavar='FILE',
type=str, type=str,
dest="inifilename", dest='inifilename',
help="Load configuration from `FILE` instead of trying to locate one of the " help='Load configuration from `FILE` instead of trying to locate one of the '
"implicit configuration files.", 'implicit configuration files.',
) )
group._addoption( group._addoption(
"--continue-on-collection-errors", '--continue-on-collection-errors',
action="store_true", action='store_true',
default=False, default=False,
dest="continue_on_collection_errors", dest='continue_on_collection_errors',
help="Force test execution even if collection errors occur", help='Force test execution even if collection errors occur',
) )
group._addoption( group._addoption(
"--rootdir", '--rootdir',
action="store", action='store',
dest="rootdir", dest='rootdir',
help="Define root directory for tests. Can be relative path: 'root_dir', './root_dir', " help="Define root directory for tests. Can be relative path: 'root_dir', './root_dir', "
"'root_dir/another_dir/'; absolute path: '/home/user/root_dir'; path with variables: " "'root_dir/another_dir/'; absolute path: '/home/user/root_dir'; path with variables: "
"'$HOME/root_dir'.", "'$HOME/root_dir'.",
) )
group = parser.getgroup("collect", "collection") group = parser.getgroup('collect', 'collection')
group.addoption( group.addoption(
"--collectonly", '--collectonly',
"--collect-only", '--collect-only',
"--co", '--co',
action="store_true", action='store_true',
help="Only collect tests, don't execute them", help="Only collect tests, don't execute them",
) )
group.addoption( group.addoption(
"--pyargs", '--pyargs',
action="store_true", action='store_true',
help="Try to interpret all arguments as Python packages", help='Try to interpret all arguments as Python packages',
) )
group.addoption( group.addoption(
"--ignore", '--ignore',
action="append", action='append',
metavar="path", metavar='path',
help="Ignore path during collection (multi-allowed)", help='Ignore path during collection (multi-allowed)',
) )
group.addoption( group.addoption(
"--ignore-glob", '--ignore-glob',
action="append", action='append',
metavar="path", metavar='path',
help="Ignore path pattern during collection (multi-allowed)", help='Ignore path pattern during collection (multi-allowed)',
) )
group.addoption( group.addoption(
"--deselect", '--deselect',
action="append", action='append',
metavar="nodeid_prefix", metavar='nodeid_prefix',
help="Deselect item (via node id prefix) during collection (multi-allowed)", help='Deselect item (via node id prefix) during collection (multi-allowed)',
) )
group.addoption( group.addoption(
"--confcutdir", '--confcutdir',
dest="confcutdir", dest='confcutdir',
default=None, default=None,
metavar="dir", metavar='dir',
type=functools.partial(directory_arg, optname="--confcutdir"), type=functools.partial(directory_arg, optname='--confcutdir'),
help="Only load conftest.py's relative to specified dir", help="Only load conftest.py's relative to specified dir",
) )
group.addoption( group.addoption(
"--noconftest", '--noconftest',
action="store_true", action='store_true',
dest="noconftest", dest='noconftest',
default=False, default=False,
help="Don't load any conftest.py files", help="Don't load any conftest.py files",
) )
group.addoption( group.addoption(
"--keepduplicates", '--keepduplicates',
"--keep-duplicates", '--keep-duplicates',
action="store_true", action='store_true',
dest="keepduplicates", dest='keepduplicates',
default=False, default=False,
help="Keep duplicate tests", help='Keep duplicate tests',
) )
group.addoption( group.addoption(
"--collect-in-virtualenv", '--collect-in-virtualenv',
action="store_true", action='store_true',
dest="collect_in_virtualenv", dest='collect_in_virtualenv',
default=False, default=False,
help="Don't ignore tests in a local virtualenv directory", help="Don't ignore tests in a local virtualenv directory",
) )
group.addoption( group.addoption(
"--import-mode", '--import-mode',
default="prepend", default='prepend',
choices=["prepend", "append", "importlib"], choices=['prepend', 'append', 'importlib'],
dest="importmode", dest='importmode',
help="Prepend/append to sys.path when importing test modules and conftest " help='Prepend/append to sys.path when importing test modules and conftest '
"files. Default: prepend.", 'files. Default: prepend.',
) )
parser.addini( parser.addini(
"consider_namespace_packages", 'consider_namespace_packages',
type="bool", type='bool',
default=False, default=False,
help="Consider namespace packages when resolving module names during import", help='Consider namespace packages when resolving module names during import',
) )
group = parser.getgroup("debugconfig", "test session debugging and configuration") group = parser.getgroup('debugconfig', 'test session debugging and configuration')
group.addoption( group.addoption(
"--basetemp", '--basetemp',
dest="basetemp", dest='basetemp',
default=None, default=None,
type=validate_basetemp, type=validate_basetemp,
metavar="dir", metavar='dir',
help=( help=(
"Base temporary directory for this test run. " 'Base temporary directory for this test run. '
"(Warning: this directory is removed if it exists.)" '(Warning: this directory is removed if it exists.)'
), ),
) )
def validate_basetemp(path: str) -> str: def validate_basetemp(path: str) -> str:
# GH 7119 # GH 7119
msg = "basetemp must not be empty, the current working directory or any parent directory of it" msg = 'basetemp must not be empty, the current working directory or any parent directory of it'
# empty path # empty path
if not path: if not path:
@ -270,8 +269,8 @@ def validate_basetemp(path: str) -> str:
def wrap_session( def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]] config: Config, doit: Callable[[Config, Session], int | ExitCode | None],
) -> Union[int, ExitCode]: ) -> int | ExitCode:
"""Skeleton command line program.""" """Skeleton command line program."""
session = Session.from_config(config) session = Session.from_config(config)
session.exitstatus = ExitCode.OK session.exitstatus = ExitCode.OK
@ -290,12 +289,12 @@ def wrap_session(
session.exitstatus = ExitCode.TESTS_FAILED session.exitstatus = ExitCode.TESTS_FAILED
except (KeyboardInterrupt, exit.Exception): except (KeyboardInterrupt, exit.Exception):
excinfo = _pytest._code.ExceptionInfo.from_current() excinfo = _pytest._code.ExceptionInfo.from_current()
exitstatus: Union[int, ExitCode] = ExitCode.INTERRUPTED exitstatus: int | ExitCode = ExitCode.INTERRUPTED
if isinstance(excinfo.value, exit.Exception): if isinstance(excinfo.value, exit.Exception):
if excinfo.value.returncode is not None: if excinfo.value.returncode is not None:
exitstatus = excinfo.value.returncode exitstatus = excinfo.value.returncode
if initstate < 2: if initstate < 2:
sys.stderr.write(f"{excinfo.typename}: {excinfo.value.msg}\n") sys.stderr.write(f'{excinfo.typename}: {excinfo.value.msg}\n')
config.hook.pytest_keyboard_interrupt(excinfo=excinfo) config.hook.pytest_keyboard_interrupt(excinfo=excinfo)
session.exitstatus = exitstatus session.exitstatus = exitstatus
except BaseException: except BaseException:
@ -306,10 +305,10 @@ def wrap_session(
except exit.Exception as exc: except exit.Exception as exc:
if exc.returncode is not None: if exc.returncode is not None:
session.exitstatus = exc.returncode session.exitstatus = exc.returncode
sys.stderr.write(f"{type(exc).__name__}: {exc}\n") sys.stderr.write(f'{type(exc).__name__}: {exc}\n')
else: else:
if isinstance(excinfo.value, SystemExit): if isinstance(excinfo.value, SystemExit):
sys.stderr.write("mainloop: caught unexpected SystemExit!\n") sys.stderr.write('mainloop: caught unexpected SystemExit!\n')
finally: finally:
# Explicitly break reference cycle. # Explicitly break reference cycle.
@ -318,21 +317,21 @@ def wrap_session(
if initstate >= 2: if initstate >= 2:
try: try:
config.hook.pytest_sessionfinish( config.hook.pytest_sessionfinish(
session=session, exitstatus=session.exitstatus session=session, exitstatus=session.exitstatus,
) )
except exit.Exception as exc: except exit.Exception as exc:
if exc.returncode is not None: if exc.returncode is not None:
session.exitstatus = exc.returncode session.exitstatus = exc.returncode
sys.stderr.write(f"{type(exc).__name__}: {exc}\n") sys.stderr.write(f'{type(exc).__name__}: {exc}\n')
config._ensure_unconfigure() config._ensure_unconfigure()
return session.exitstatus return session.exitstatus
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]: def pytest_cmdline_main(config: Config) -> int | ExitCode:
return wrap_session(config, _main) return wrap_session(config, _main)
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]: def _main(config: Config, session: Session) -> int | ExitCode | None:
"""Default command line protocol for initialization, session, """Default command line protocol for initialization, session,
running tests and reporting.""" running tests and reporting."""
config.hook.pytest_collection(session=session) config.hook.pytest_collection(session=session)
@ -345,15 +344,15 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None return None
def pytest_collection(session: "Session") -> None: def pytest_collection(session: Session) -> None:
session.perform_collect() session.perform_collect()
def pytest_runtestloop(session: "Session") -> bool: def pytest_runtestloop(session: Session) -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors: if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted( raise session.Interrupted(
"%d error%s during collection" '%d error%s during collection'
% (session.testsfailed, "s" if session.testsfailed != 1 else "") % (session.testsfailed, 's' if session.testsfailed != 1 else ''),
) )
if session.config.option.collectonly: if session.config.option.collectonly:
@ -372,32 +371,32 @@ def pytest_runtestloop(session: "Session") -> bool:
def _in_venv(path: Path) -> bool: def _in_venv(path: Path) -> bool:
"""Attempt to detect if ``path`` is the root of a Virtual Environment by """Attempt to detect if ``path`` is the root of a Virtual Environment by
checking for the existence of the appropriate activate script.""" checking for the existence of the appropriate activate script."""
bindir = path.joinpath("Scripts" if sys.platform.startswith("win") else "bin") bindir = path.joinpath('Scripts' if sys.platform.startswith('win') else 'bin')
try: try:
if not bindir.is_dir(): if not bindir.is_dir():
return False return False
except OSError: except OSError:
return False return False
activates = ( activates = (
"activate", 'activate',
"activate.csh", 'activate.csh',
"activate.fish", 'activate.fish',
"Activate", 'Activate',
"Activate.bat", 'Activate.bat',
"Activate.ps1", 'Activate.ps1',
) )
return any(fname.name in activates for fname in bindir.iterdir()) return any(fname.name in activates for fname in bindir.iterdir())
def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[bool]: def pytest_ignore_collect(collection_path: Path, config: Config) -> bool | None:
if collection_path.name == "__pycache__": if collection_path.name == '__pycache__':
return True return True
ignore_paths = config._getconftest_pathlist( ignore_paths = config._getconftest_pathlist(
"collect_ignore", path=collection_path.parent 'collect_ignore', path=collection_path.parent,
) )
ignore_paths = ignore_paths or [] ignore_paths = ignore_paths or []
excludeopt = config.getoption("ignore") excludeopt = config.getoption('ignore')
if excludeopt: if excludeopt:
ignore_paths.extend(absolutepath(x) for x in excludeopt) ignore_paths.extend(absolutepath(x) for x in excludeopt)
@ -405,22 +404,22 @@ def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[boo
return True return True
ignore_globs = config._getconftest_pathlist( ignore_globs = config._getconftest_pathlist(
"collect_ignore_glob", path=collection_path.parent 'collect_ignore_glob', path=collection_path.parent,
) )
ignore_globs = ignore_globs or [] ignore_globs = ignore_globs or []
excludeglobopt = config.getoption("ignore_glob") excludeglobopt = config.getoption('ignore_glob')
if excludeglobopt: if excludeglobopt:
ignore_globs.extend(absolutepath(x) for x in excludeglobopt) ignore_globs.extend(absolutepath(x) for x in excludeglobopt)
if any(fnmatch.fnmatch(str(collection_path), str(glob)) for glob in ignore_globs): if any(fnmatch.fnmatch(str(collection_path), str(glob)) for glob in ignore_globs):
return True return True
allow_in_venv = config.getoption("collect_in_virtualenv") allow_in_venv = config.getoption('collect_in_virtualenv')
if not allow_in_venv and _in_venv(collection_path): if not allow_in_venv and _in_venv(collection_path):
return True return True
if collection_path.is_dir(): if collection_path.is_dir():
norecursepatterns = config.getini("norecursedirs") norecursepatterns = config.getini('norecursedirs')
if any(fnmatch_ex(pat, collection_path) for pat in norecursepatterns): if any(fnmatch_ex(pat, collection_path) for pat in norecursepatterns):
return True return True
@ -428,13 +427,13 @@ def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[boo
def pytest_collect_directory( def pytest_collect_directory(
path: Path, parent: nodes.Collector path: Path, parent: nodes.Collector,
) -> Optional[nodes.Collector]: ) -> nodes.Collector | None:
return Dir.from_parent(parent, path=path) return Dir.from_parent(parent, path=path)
def pytest_collection_modifyitems(items: List[nodes.Item], config: Config) -> None: def pytest_collection_modifyitems(items: list[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or []) deselect_prefixes = tuple(config.getoption('deselect') or [])
if not deselect_prefixes: if not deselect_prefixes:
return return
@ -469,7 +468,7 @@ class FSHookProxy:
class Interrupted(KeyboardInterrupt): class Interrupted(KeyboardInterrupt):
"""Signals that the test run was interrupted.""" """Signals that the test run was interrupted."""
__module__ = "builtins" # For py3. __module__ = 'builtins' # For py3.
class Failed(Exception): class Failed(Exception):
@ -478,7 +477,7 @@ class Failed(Exception):
@dataclasses.dataclass @dataclasses.dataclass
class _bestrelpath_cache(Dict[Path, str]): class _bestrelpath_cache(Dict[Path, str]):
__slots__ = ("path",) __slots__ = ('path',)
path: Path path: Path
@ -507,7 +506,7 @@ class Dir(nodes.Directory):
parent: nodes.Collector, parent: nodes.Collector,
*, *,
path: Path, path: Path,
) -> "Self": ) -> Self:
"""The public constructor. """The public constructor.
:param parent: The parent collector of this Dir. :param parent: The parent collector of this Dir.
@ -515,9 +514,9 @@ class Dir(nodes.Directory):
""" """
return super().from_parent(parent=parent, path=path) return super().from_parent(parent=parent, path=path)
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
config = self.config config = self.config
col: Optional[nodes.Collector] col: nodes.Collector | None
cols: Sequence[nodes.Collector] cols: Sequence[nodes.Collector]
ihook = self.ihook ihook = self.ihook
for direntry in scandir(self.path): for direntry in scandir(self.path):
@ -552,60 +551,60 @@ class Session(nodes.Collector):
_setupstate: SetupState _setupstate: SetupState
# Set on the session by fixtures.pytest_sessionstart. # Set on the session by fixtures.pytest_sessionstart.
_fixturemanager: FixtureManager _fixturemanager: FixtureManager
exitstatus: Union[int, ExitCode] exitstatus: int | ExitCode
def __init__(self, config: Config) -> None: def __init__(self, config: Config) -> None:
super().__init__( super().__init__(
name="", name='',
path=config.rootpath, path=config.rootpath,
fspath=None, fspath=None,
parent=None, parent=None,
config=config, config=config,
session=self, session=self,
nodeid="", nodeid='',
) )
self.testsfailed = 0 self.testsfailed = 0
self.testscollected = 0 self.testscollected = 0
self._shouldstop: Union[bool, str] = False self._shouldstop: bool | str = False
self._shouldfail: Union[bool, str] = False self._shouldfail: bool | str = False
self.trace = config.trace.root.get("collection") self.trace = config.trace.root.get('collection')
self._initialpaths: FrozenSet[Path] = frozenset() self._initialpaths: frozenset[Path] = frozenset()
self._initialpaths_with_parents: FrozenSet[Path] = frozenset() self._initialpaths_with_parents: frozenset[Path] = frozenset()
self._notfound: List[Tuple[str, Sequence[nodes.Collector]]] = [] self._notfound: list[tuple[str, Sequence[nodes.Collector]]] = []
self._initial_parts: List[CollectionArgument] = [] self._initial_parts: list[CollectionArgument] = []
self._collection_cache: Dict[nodes.Collector, CollectReport] = {} self._collection_cache: dict[nodes.Collector, CollectReport] = {}
self.items: List[nodes.Item] = [] self.items: list[nodes.Item] = []
self._bestrelpathcache: Dict[Path, str] = _bestrelpath_cache(config.rootpath) self._bestrelpathcache: dict[Path, str] = _bestrelpath_cache(config.rootpath)
self.config.pluginmanager.register(self, name="session") self.config.pluginmanager.register(self, name='session')
@classmethod @classmethod
def from_config(cls, config: Config) -> "Session": def from_config(cls, config: Config) -> Session:
session: Session = cls._create(config=config) session: Session = cls._create(config=config)
return session return session
def __repr__(self) -> str: def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % ( return '<%s %s exitstatus=%r testsfailed=%d testscollected=%d>' % (
self.__class__.__name__, self.__class__.__name__,
self.name, self.name,
getattr(self, "exitstatus", "<UNSET>"), getattr(self, 'exitstatus', '<UNSET>'),
self.testsfailed, self.testsfailed,
self.testscollected, self.testscollected,
) )
@property @property
def shouldstop(self) -> Union[bool, str]: def shouldstop(self) -> bool | str:
return self._shouldstop return self._shouldstop
@shouldstop.setter @shouldstop.setter
def shouldstop(self, value: Union[bool, str]) -> None: def shouldstop(self, value: bool | str) -> None:
# The runner checks shouldfail and assumes that if it is set we are # The runner checks shouldfail and assumes that if it is set we are
# definitely stopping, so prevent unsetting it. # definitely stopping, so prevent unsetting it.
if value is False and self._shouldstop: if value is False and self._shouldstop:
warnings.warn( warnings.warn(
PytestWarning( PytestWarning(
"session.shouldstop cannot be unset after it has been set; ignoring." 'session.shouldstop cannot be unset after it has been set; ignoring.',
), ),
stacklevel=2, stacklevel=2,
) )
@ -613,17 +612,17 @@ class Session(nodes.Collector):
self._shouldstop = value self._shouldstop = value
@property @property
def shouldfail(self) -> Union[bool, str]: def shouldfail(self) -> bool | str:
return self._shouldfail return self._shouldfail
@shouldfail.setter @shouldfail.setter
def shouldfail(self, value: Union[bool, str]) -> None: def shouldfail(self, value: bool | str) -> None:
# The runner checks shouldfail and assumes that if it is set we are # The runner checks shouldfail and assumes that if it is set we are
# definitely stopping, so prevent unsetting it. # definitely stopping, so prevent unsetting it.
if value is False and self._shouldfail: if value is False and self._shouldfail:
warnings.warn( warnings.warn(
PytestWarning( PytestWarning(
"session.shouldfail cannot be unset after it has been set; ignoring." 'session.shouldfail cannot be unset after it has been set; ignoring.',
), ),
stacklevel=2, stacklevel=2,
) )
@ -651,19 +650,19 @@ class Session(nodes.Collector):
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
def pytest_runtest_logreport( def pytest_runtest_logreport(
self, report: Union[TestReport, CollectReport] self, report: TestReport | CollectReport,
) -> None: ) -> None:
if report.failed and not hasattr(report, "wasxfail"): if report.failed and not hasattr(report, 'wasxfail'):
self.testsfailed += 1 self.testsfailed += 1
maxfail = self.config.getvalue("maxfail") maxfail = self.config.getvalue('maxfail')
if maxfail and self.testsfailed >= maxfail: if maxfail and self.testsfailed >= maxfail:
self.shouldfail = "stopping after %d failures" % (self.testsfailed) self.shouldfail = 'stopping after %d failures' % (self.testsfailed)
pytest_collectreport = pytest_runtest_logreport pytest_collectreport = pytest_runtest_logreport
def isinitpath( def isinitpath(
self, self,
path: Union[str, "os.PathLike[str]"], path: str | os.PathLike[str],
*, *,
with_parents: bool = False, with_parents: bool = False,
) -> bool: ) -> bool:
@ -685,7 +684,7 @@ class Session(nodes.Collector):
else: else:
return path_ in self._initialpaths return path_ in self._initialpaths
def gethookproxy(self, fspath: "os.PathLike[str]") -> pluggy.HookRelay: def gethookproxy(self, fspath: os.PathLike[str]) -> pluggy.HookRelay:
# Optimization: Path(Path(...)) is much slower than isinstance. # Optimization: Path(Path(...)) is much slower than isinstance.
path = fspath if isinstance(fspath, Path) else Path(fspath) path = fspath if isinstance(fspath, Path) else Path(fspath)
pm = self.config.pluginmanager pm = self.config.pluginmanager
@ -705,7 +704,7 @@ class Session(nodes.Collector):
def _collect_path( def _collect_path(
self, self,
path: Path, path: Path,
path_cache: Dict[Path, Sequence[nodes.Collector]], path_cache: dict[Path, Sequence[nodes.Collector]],
) -> Sequence[nodes.Collector]: ) -> Sequence[nodes.Collector]:
"""Create a Collector for the given path. """Create a Collector for the given path.
@ -717,8 +716,8 @@ class Session(nodes.Collector):
if path.is_dir(): if path.is_dir():
ihook = self.gethookproxy(path.parent) ihook = self.gethookproxy(path.parent)
col: Optional[nodes.Collector] = ihook.pytest_collect_directory( col: nodes.Collector | None = ihook.pytest_collect_directory(
path=path, parent=self path=path, parent=self,
) )
cols: Sequence[nodes.Collector] = (col,) if col is not None else () cols: Sequence[nodes.Collector] = (col,) if col is not None else ()
@ -735,19 +734,19 @@ class Session(nodes.Collector):
@overload @overload
def perform_collect( def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ... self, args: Sequence[str] | None = ..., genitems: Literal[True] = ...,
) -> Sequence[nodes.Item]: ) -> Sequence[nodes.Item]:
... ...
@overload @overload
def perform_collect( def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: bool = ... self, args: Sequence[str] | None = ..., genitems: bool = ...,
) -> Sequence[Union[nodes.Item, nodes.Collector]]: ) -> Sequence[nodes.Item | nodes.Collector]:
... ...
def perform_collect( def perform_collect(
self, args: Optional[Sequence[str]] = None, genitems: bool = True self, args: Sequence[str] | None = None, genitems: bool = True,
) -> Sequence[Union[nodes.Item, nodes.Collector]]: ) -> Sequence[nodes.Item | nodes.Collector]:
"""Perform the collection phase for this session. """Perform the collection phase for this session.
This is called by the default :hook:`pytest_collection` hook This is called by the default :hook:`pytest_collection` hook
@ -764,7 +763,7 @@ class Session(nodes.Collector):
if args is None: if args is None:
args = self.config.args args = self.config.args
self.trace("perform_collect", self, args) self.trace('perform_collect', self, args)
self.trace.root.indent += 1 self.trace.root.indent += 1
hook = self.config.hook hook = self.config.hook
@ -773,10 +772,10 @@ class Session(nodes.Collector):
self._initial_parts = [] self._initial_parts = []
self._collection_cache = {} self._collection_cache = {}
self.items = [] self.items = []
items: Sequence[Union[nodes.Item, nodes.Collector]] = self.items items: Sequence[nodes.Item | nodes.Collector] = self.items
try: try:
initialpaths: List[Path] = [] initialpaths: list[Path] = []
initialpaths_with_parents: List[Path] = [] initialpaths_with_parents: list[Path] = []
for arg in args: for arg in args:
collection_argument = resolve_collection_argument( collection_argument = resolve_collection_argument(
self.config.invocation_params.dir, self.config.invocation_params.dir,
@ -798,10 +797,10 @@ class Session(nodes.Collector):
for arg, collectors in self._notfound: for arg, collectors in self._notfound:
if collectors: if collectors:
errors.append( errors.append(
f"not found: {arg}\n(no match in any of {collectors!r})" f'not found: {arg}\n(no match in any of {collectors!r})',
) )
else: else:
errors.append(f"found no collectors for {arg}") errors.append(f'found no collectors for {arg}')
raise UsageError(*errors) raise UsageError(*errors)
@ -814,7 +813,7 @@ class Session(nodes.Collector):
self.config.pluginmanager.check_pending() self.config.pluginmanager.check_pending()
hook.pytest_collection_modifyitems( hook.pytest_collection_modifyitems(
session=self, config=self.config, items=items session=self, config=self.config, items=items,
) )
finally: finally:
self._notfound = [] self._notfound = []
@ -831,7 +830,7 @@ class Session(nodes.Collector):
self, self,
node: nodes.Collector, node: nodes.Collector,
handle_dupes: bool = True, handle_dupes: bool = True,
) -> Tuple[CollectReport, bool]: ) -> tuple[CollectReport, bool]:
if node in self._collection_cache and handle_dupes: if node in self._collection_cache and handle_dupes:
rep = self._collection_cache[node] rep = self._collection_cache[node]
return rep, True return rep, True
@ -840,16 +839,16 @@ class Session(nodes.Collector):
self._collection_cache[node] = rep self._collection_cache[node] = rep
return rep, False return rep, False
def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterator[nodes.Item | nodes.Collector]:
# This is a cache for the root directories of the initial paths. # This is a cache for the root directories of the initial paths.
# We can't use collection_cache for Session because of its special # We can't use collection_cache for Session because of its special
# role as the bootstrapping collector. # role as the bootstrapping collector.
path_cache: Dict[Path, Sequence[nodes.Collector]] = {} path_cache: dict[Path, Sequence[nodes.Collector]] = {}
pm = self.config.pluginmanager pm = self.config.pluginmanager
for collection_argument in self._initial_parts: for collection_argument in self._initial_parts:
self.trace("processing argument", collection_argument) self.trace('processing argument', collection_argument)
self.trace.root.indent += 1 self.trace.root.indent += 1
argpath = collection_argument.path argpath = collection_argument.path
@ -858,7 +857,7 @@ class Session(nodes.Collector):
# resolve_collection_argument() ensures this. # resolve_collection_argument() ensures this.
if argpath.is_dir(): if argpath.is_dir():
assert not names, f"invalid arg {(argpath, names)!r}" assert not names, f'invalid arg {(argpath, names)!r}'
paths = [argpath] paths = [argpath]
# Add relevant parents of the path, from the root, e.g. # Add relevant parents of the path, from the root, e.g.
@ -872,7 +871,7 @@ class Session(nodes.Collector):
else: else:
# For --pyargs arguments, only consider paths matching the module # For --pyargs arguments, only consider paths matching the module
# name. Paths beyond the package hierarchy are not included. # name. Paths beyond the package hierarchy are not included.
module_name_parts = module_name.split(".") module_name_parts = module_name.split('.')
for i, path in enumerate(argpath.parents, 2): for i, path in enumerate(argpath.parents, 2):
if i > len(module_name_parts) or path.stem != module_name_parts[-i]: if i > len(module_name_parts) or path.stem != module_name_parts[-i]:
break break
@ -882,8 +881,8 @@ class Session(nodes.Collector):
# and discarding all nodes which don't match the level's part. # and discarding all nodes which don't match the level's part.
any_matched_in_initial_part = False any_matched_in_initial_part = False
notfound_collectors = [] notfound_collectors = []
work: List[ work: list[
Tuple[Union[nodes.Collector, nodes.Item], List[Union[Path, str]]] tuple[nodes.Collector | nodes.Item, list[Path | str]]
] = [(self, [*paths, *names])] ] = [(self, [*paths, *names])]
while work: while work:
matchnode, matchparts = work.pop() matchnode, matchparts = work.pop()
@ -901,7 +900,7 @@ class Session(nodes.Collector):
# Collect this level of matching. # Collect this level of matching.
# Collecting Session (self) is done directly to avoid endless # Collecting Session (self) is done directly to avoid endless
# recursion to this function. # recursion to this function.
subnodes: Sequence[Union[nodes.Collector, nodes.Item]] subnodes: Sequence[nodes.Collector | nodes.Item]
if isinstance(matchnode, Session): if isinstance(matchnode, Session):
assert isinstance(matchparts[0], Path) assert isinstance(matchparts[0], Path)
subnodes = matchnode._collect_path(matchparts[0], path_cache) subnodes = matchnode._collect_path(matchparts[0], path_cache)
@ -909,9 +908,9 @@ class Session(nodes.Collector):
# For backward compat, files given directly multiple # For backward compat, files given directly multiple
# times on the command line should not be deduplicated. # times on the command line should not be deduplicated.
handle_dupes = not ( handle_dupes = not (
len(matchparts) == 1 len(matchparts) == 1 and
and isinstance(matchparts[0], Path) isinstance(matchparts[0], Path) and
and matchparts[0].is_file() matchparts[0].is_file()
) )
rep, duplicate = self._collect_one_node(matchnode, handle_dupes) rep, duplicate = self._collect_one_node(matchnode, handle_dupes)
if not duplicate and not rep.passed: if not duplicate and not rep.passed:
@ -929,15 +928,15 @@ class Session(nodes.Collector):
# Path part e.g. `/a/b/` in `/a/b/test_file.py::TestIt::test_it`. # Path part e.g. `/a/b/` in `/a/b/test_file.py::TestIt::test_it`.
if isinstance(matchparts[0], Path): if isinstance(matchparts[0], Path):
is_match = node.path == matchparts[0] is_match = node.path == matchparts[0]
if sys.platform == "win32" and not is_match: if sys.platform == 'win32' and not is_match:
# In case the file paths do not match, fallback to samefile() to # In case the file paths do not match, fallback to samefile() to
# account for short-paths on Windows (#11895). # account for short-paths on Windows (#11895).
same_file = os.path.samefile(node.path, matchparts[0]) same_file = os.path.samefile(node.path, matchparts[0])
# We don't want to match links to the current node, # We don't want to match links to the current node,
# otherwise we would match the same file more than once (#12039). # otherwise we would match the same file more than once (#12039).
is_match = same_file and ( is_match = same_file and (
os.path.islink(node.path) os.path.islink(node.path) ==
== os.path.islink(matchparts[0]) os.path.islink(matchparts[0])
) )
# Name part e.g. `TestIt` in `/a/b/test_file.py::TestIt::test_it`. # Name part e.g. `TestIt` in `/a/b/test_file.py::TestIt::test_it`.
@ -945,8 +944,8 @@ class Session(nodes.Collector):
# TODO: Remove parametrized workaround once collection structure contains # TODO: Remove parametrized workaround once collection structure contains
# parametrization. # parametrization.
is_match = ( is_match = (
node.name == matchparts[0] node.name == matchparts[0] or
or node.name.split("[")[0] == matchparts[0] node.name.split('[')[0] == matchparts[0]
) )
if is_match: if is_match:
work.append((node, matchparts[1:])) work.append((node, matchparts[1:]))
@ -956,21 +955,21 @@ class Session(nodes.Collector):
notfound_collectors.append(matchnode) notfound_collectors.append(matchnode)
if not any_matched_in_initial_part: if not any_matched_in_initial_part:
report_arg = "::".join((str(argpath), *names)) report_arg = '::'.join((str(argpath), *names))
self._notfound.append((report_arg, notfound_collectors)) self._notfound.append((report_arg, notfound_collectors))
self.trace.root.indent -= 1 self.trace.root.indent -= 1
def genitems( def genitems(
self, node: Union[nodes.Item, nodes.Collector] self, node: nodes.Item | nodes.Collector,
) -> Iterator[nodes.Item]: ) -> Iterator[nodes.Item]:
self.trace("genitems", node) self.trace('genitems', node)
if isinstance(node, nodes.Item): if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node) node.ihook.pytest_itemcollected(item=node)
yield node yield node
else: else:
assert isinstance(node, nodes.Collector) assert isinstance(node, nodes.Collector)
keepduplicates = self.config.getoption("keepduplicates") keepduplicates = self.config.getoption('keepduplicates')
# For backward compat, dedup only applies to files. # For backward compat, dedup only applies to files.
handle_dupes = not (keepduplicates and isinstance(node, nodes.File)) handle_dupes = not (keepduplicates and isinstance(node, nodes.File))
rep, duplicate = self._collect_one_node(node, handle_dupes) rep, duplicate = self._collect_one_node(node, handle_dupes)
@ -983,7 +982,7 @@ class Session(nodes.Collector):
node.ihook.pytest_collectreport(report=rep) node.ihook.pytest_collectreport(report=rep)
def search_pypath(module_name: str) -> Optional[str]: def search_pypath(module_name: str) -> str | None:
"""Search sys.path for the given a dotted module name, and return its file """Search sys.path for the given a dotted module name, and return its file
system path if found.""" system path if found."""
try: try:
@ -993,7 +992,7 @@ def search_pypath(module_name: str) -> Optional[str]:
# ValueError: not a module name # ValueError: not a module name
except (AttributeError, ImportError, ValueError): except (AttributeError, ImportError, ValueError):
return None return None
if spec is None or spec.origin is None or spec.origin == "namespace": if spec is None or spec.origin is None or spec.origin == 'namespace':
return None return None
elif spec.submodule_search_locations: elif spec.submodule_search_locations:
return os.path.dirname(spec.origin) return os.path.dirname(spec.origin)
@ -1007,11 +1006,11 @@ class CollectionArgument:
path: Path path: Path
parts: Sequence[str] parts: Sequence[str]
module_name: Optional[str] module_name: str | None
def resolve_collection_argument( def resolve_collection_argument(
invocation_path: Path, arg: str, *, as_pypath: bool = False invocation_path: Path, arg: str, *, as_pypath: bool = False,
) -> CollectionArgument: ) -> CollectionArgument:
"""Parse path arguments optionally containing selection parts and return (fspath, names). """Parse path arguments optionally containing selection parts and return (fspath, names).
@ -1045,10 +1044,10 @@ def resolve_collection_argument(
If the path doesn't exist, raise UsageError. If the path doesn't exist, raise UsageError.
If the path is a directory and selection parts are present, raise UsageError. If the path is a directory and selection parts are present, raise UsageError.
""" """
base, squacket, rest = str(arg).partition("[") base, squacket, rest = str(arg).partition('[')
strpath, *parts = base.split("::") strpath, *parts = base.split('::')
if parts: if parts:
parts[-1] = f"{parts[-1]}{squacket}{rest}" parts[-1] = f'{parts[-1]}{squacket}{rest}'
module_name = None module_name = None
if as_pypath: if as_pypath:
pyarg_strpath = search_pypath(strpath) pyarg_strpath = search_pypath(strpath)
@ -1059,16 +1058,16 @@ def resolve_collection_argument(
fspath = absolutepath(fspath) fspath = absolutepath(fspath)
if not safe_exists(fspath): if not safe_exists(fspath):
msg = ( msg = (
"module or package not found: {arg} (missing __init__.py?)" 'module or package not found: {arg} (missing __init__.py?)'
if as_pypath if as_pypath
else "file or directory not found: {arg}" else 'file or directory not found: {arg}'
) )
raise UsageError(msg.format(arg=arg)) raise UsageError(msg.format(arg=arg))
if parts and fspath.is_dir(): if parts and fspath.is_dir():
msg = ( msg = (
"package argument cannot contain :: selection parts: {arg}" 'package argument cannot contain :: selection parts: {arg}'
if as_pypath if as_pypath
else "directory argument cannot contain :: selection parts: {arg}" else 'directory argument cannot contain :: selection parts: {arg}'
) )
raise UsageError(msg.format(arg=arg)) raise UsageError(msg.format(arg=arg))
return CollectionArgument( return CollectionArgument(

View file

@ -1,4 +1,5 @@
"""Generic mechanism for marking and selecting python functions.""" """Generic mechanism for marking and selecting python functions."""
from __future__ import annotations
import dataclasses import dataclasses
from typing import AbstractSet from typing import AbstractSet
@ -8,6 +9,13 @@ from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
from .expression import Expression from .expression import Expression
from .expression import ParseError from .expression import ParseError
from .structures import EMPTY_PARAMETERSET_OPTION from .structures import EMPTY_PARAMETERSET_OPTION
@ -17,12 +25,6 @@ from .structures import MARK_GEN
from .structures import MarkDecorator from .structures import MarkDecorator
from .structures import MarkGenerator from .structures import MarkGenerator
from .structures import ParameterSet from .structures import ParameterSet
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
if TYPE_CHECKING: if TYPE_CHECKING:
@ -30,12 +32,12 @@ if TYPE_CHECKING:
__all__ = [ __all__ = [
"MARK_GEN", 'MARK_GEN',
"Mark", 'Mark',
"MarkDecorator", 'MarkDecorator',
"MarkGenerator", 'MarkGenerator',
"ParameterSet", 'ParameterSet',
"get_empty_parameterset_mark", 'get_empty_parameterset_mark',
] ]
@ -44,8 +46,8 @@ old_mark_config_key = StashKey[Optional[Config]]()
def param( def param(
*values: object, *values: object,
marks: Union[MarkDecorator, Collection[Union[MarkDecorator, Mark]]] = (), marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: Optional[str] = None, id: str | None = None,
) -> ParameterSet: ) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or """Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`. :ref:`parametrized fixtures <fixture-parametrize-marks>`.
@ -70,59 +72,59 @@ def param(
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general") group = parser.getgroup('general')
group._addoption( group._addoption(
"-k", '-k',
action="store", action='store',
dest="keyword", dest='keyword',
default="", default='',
metavar="EXPRESSION", metavar='EXPRESSION',
help="Only run tests which match the given substring expression. " help='Only run tests which match the given substring expression. '
"An expression is a Python evaluatable expression " 'An expression is a Python evaluatable expression '
"where all names are substring-matched against test names " 'where all names are substring-matched against test names '
"and their parent classes. Example: -k 'test_method or test_" "and their parent classes. Example: -k 'test_method or test_"
"other' matches all test functions and classes whose name " "other' matches all test functions and classes whose name "
"contains 'test_method' or 'test_other', while -k 'not test_method' " "contains 'test_method' or 'test_other', while -k 'not test_method' "
"matches those that don't contain 'test_method' in their names. " "matches those that don't contain 'test_method' in their names. "
"-k 'not test_method and not test_other' will eliminate the matches. " "-k 'not test_method and not test_other' will eliminate the matches. "
"Additionally keywords are matched to classes and functions " 'Additionally keywords are matched to classes and functions '
"containing extra names in their 'extra_keyword_matches' set, " "containing extra names in their 'extra_keyword_matches' set, "
"as well as functions which have names assigned directly to them. " 'as well as functions which have names assigned directly to them. '
"The matching is case-insensitive.", 'The matching is case-insensitive.',
) )
group._addoption( group._addoption(
"-m", '-m',
action="store", action='store',
dest="markexpr", dest='markexpr',
default="", default='',
metavar="MARKEXPR", metavar='MARKEXPR',
help="Only run tests matching given mark expression. " help='Only run tests matching given mark expression. '
"For example: -m 'mark1 and not mark2'.", "For example: -m 'mark1 and not mark2'.",
) )
group.addoption( group.addoption(
"--markers", '--markers',
action="store_true", action='store_true',
help="show markers (builtin, plugin and per-project ones).", help='show markers (builtin, plugin and per-project ones).',
) )
parser.addini("markers", "Register new markers for test functions", "linelist") parser.addini('markers', 'Register new markers for test functions', 'linelist')
parser.addini(EMPTY_PARAMETERSET_OPTION, "Default marker for empty parametersets") parser.addini(EMPTY_PARAMETERSET_OPTION, 'Default marker for empty parametersets')
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
import _pytest.config import _pytest.config
if config.option.markers: if config.option.markers:
config._do_configure() config._do_configure()
tw = _pytest.config.create_terminal_writer(config) tw = _pytest.config.create_terminal_writer(config)
for line in config.getini("markers"): for line in config.getini('markers'):
parts = line.split(":", 1) parts = line.split(':', 1)
name = parts[0] name = parts[0]
rest = parts[1] if len(parts) == 2 else "" rest = parts[1] if len(parts) == 2 else ''
tw.write("@pytest.mark.%s:" % name, bold=True) tw.write('@pytest.mark.%s:' % name, bold=True)
tw.line(rest) tw.line(rest)
tw.line() tw.line()
config._ensure_unconfigure() config._ensure_unconfigure()
@ -146,12 +148,12 @@ class KeywordMatcher:
any item, as well as names directly assigned to test functions. any item, as well as names directly assigned to test functions.
""" """
__slots__ = ("_names",) __slots__ = ('_names',)
_names: AbstractSet[str] _names: AbstractSet[str]
@classmethod @classmethod
def from_item(cls, item: "Item") -> "KeywordMatcher": def from_item(cls, item: Item) -> KeywordMatcher:
mapped_names = set() mapped_names = set()
# Add the names of the current item and any parent items, # Add the names of the current item and any parent items,
@ -163,7 +165,7 @@ class KeywordMatcher:
if isinstance(node, pytest.Session): if isinstance(node, pytest.Session):
continue continue
if isinstance(node, pytest.Directory) and isinstance( if isinstance(node, pytest.Directory) and isinstance(
node.parent, pytest.Session node.parent, pytest.Session,
): ):
continue continue
mapped_names.add(node.name) mapped_names.add(node.name)
@ -172,7 +174,7 @@ class KeywordMatcher:
mapped_names.update(item.listextrakeywords()) mapped_names.update(item.listextrakeywords())
# Add the names attached to the current function through direct assignment. # Add the names attached to the current function through direct assignment.
function_obj = getattr(item, "function", None) function_obj = getattr(item, 'function', None)
if function_obj: if function_obj:
mapped_names.update(function_obj.__dict__) mapped_names.update(function_obj.__dict__)
@ -191,7 +193,7 @@ class KeywordMatcher:
return False return False
def deselect_by_keyword(items: "List[Item]", config: Config) -> None: def deselect_by_keyword(items: List[Item], config: Config) -> None:
keywordexpr = config.option.keyword.lstrip() keywordexpr = config.option.keyword.lstrip()
if not keywordexpr: if not keywordexpr:
return return
@ -218,12 +220,12 @@ class MarkMatcher:
Tries to match on any marker names, attached to the given colitem. Tries to match on any marker names, attached to the given colitem.
""" """
__slots__ = ("own_mark_names",) __slots__ = ('own_mark_names',)
own_mark_names: AbstractSet[str] own_mark_names: AbstractSet[str]
@classmethod @classmethod
def from_item(cls, item: "Item") -> "MarkMatcher": def from_item(cls, item: Item) -> MarkMatcher:
mark_names = {mark.name for mark in item.iter_markers()} mark_names = {mark.name for mark in item.iter_markers()}
return cls(mark_names) return cls(mark_names)
@ -231,14 +233,14 @@ class MarkMatcher:
return name in self.own_mark_names return name in self.own_mark_names
def deselect_by_mark(items: "List[Item]", config: Config) -> None: def deselect_by_mark(items: List[Item], config: Config) -> None:
matchexpr = config.option.markexpr matchexpr = config.option.markexpr
if not matchexpr: if not matchexpr:
return return
expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'") expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'")
remaining: List[Item] = [] remaining: list[Item] = []
deselected: List[Item] = [] deselected: list[Item] = []
for item in items: for item in items:
if expr.evaluate(MarkMatcher.from_item(item)): if expr.evaluate(MarkMatcher.from_item(item)):
remaining.append(item) remaining.append(item)
@ -253,10 +255,10 @@ def _parse_expression(expr: str, exc_message: str) -> Expression:
try: try:
return Expression.compile(expr) return Expression.compile(expr)
except ParseError as e: except ParseError as e:
raise UsageError(f"{exc_message}: {expr}: {e}") from None raise UsageError(f'{exc_message}: {expr}: {e}') from None
def pytest_collection_modifyitems(items: "List[Item]", config: Config) -> None: def pytest_collection_modifyitems(items: List[Item], config: Config) -> None:
deselect_by_keyword(items, config) deselect_by_keyword(items, config)
deselect_by_mark(items, config) deselect_by_mark(items, config)
@ -267,10 +269,10 @@ def pytest_configure(config: Config) -> None:
empty_parameterset = config.getini(EMPTY_PARAMETERSET_OPTION) empty_parameterset = config.getini(EMPTY_PARAMETERSET_OPTION)
if empty_parameterset not in ("skip", "xfail", "fail_at_collect", None, ""): if empty_parameterset not in ('skip', 'xfail', 'fail_at_collect', None, ''):
raise UsageError( raise UsageError(
f"{EMPTY_PARAMETERSET_OPTION!s} must be one of skip, xfail or fail_at_collect" f'{EMPTY_PARAMETERSET_OPTION!s} must be one of skip, xfail or fail_at_collect'
f" but it is {empty_parameterset!r}" f' but it is {empty_parameterset!r}',
) )

View file

@ -14,6 +14,7 @@ The semantics are:
- ident evaluates to True of False according to a provided matcher function. - ident evaluates to True of False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics. - or/and/not evaluate according to the usual boolean semantics.
""" """
from __future__ import annotations
import ast import ast
import dataclasses import dataclasses
@ -29,24 +30,24 @@ from typing import Sequence
__all__ = [ __all__ = [
"Expression", 'Expression',
"ParseError", 'ParseError',
] ]
class TokenType(enum.Enum): class TokenType(enum.Enum):
LPAREN = "left parenthesis" LPAREN = 'left parenthesis'
RPAREN = "right parenthesis" RPAREN = 'right parenthesis'
OR = "or" OR = 'or'
AND = "and" AND = 'and'
NOT = "not" NOT = 'not'
IDENT = "identifier" IDENT = 'identifier'
EOF = "end of input" EOF = 'end of input'
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Token: class Token:
__slots__ = ("type", "value", "pos") __slots__ = ('type', 'value', 'pos')
type: TokenType type: TokenType
value: str value: str
pos: int pos: int
@ -64,11 +65,11 @@ class ParseError(Exception):
self.message = message self.message = message
def __str__(self) -> str: def __str__(self) -> str:
return f"at column {self.column}: {self.message}" return f'at column {self.column}: {self.message}'
class Scanner: class Scanner:
__slots__ = ("tokens", "current") __slots__ = ('tokens', 'current')
def __init__(self, input: str) -> None: def __init__(self, input: str) -> None:
self.tokens = self.lex(input) self.tokens = self.lex(input)
@ -77,23 +78,23 @@ class Scanner:
def lex(self, input: str) -> Iterator[Token]: def lex(self, input: str) -> Iterator[Token]:
pos = 0 pos = 0
while pos < len(input): while pos < len(input):
if input[pos] in (" ", "\t"): if input[pos] in (' ', '\t'):
pos += 1 pos += 1
elif input[pos] == "(": elif input[pos] == '(':
yield Token(TokenType.LPAREN, "(", pos) yield Token(TokenType.LPAREN, '(', pos)
pos += 1 pos += 1
elif input[pos] == ")": elif input[pos] == ')':
yield Token(TokenType.RPAREN, ")", pos) yield Token(TokenType.RPAREN, ')', pos)
pos += 1 pos += 1
else: else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:]) match = re.match(r'(:?\w|:|\+|-|\.|\[|\]|\\|/)+', input[pos:])
if match: if match:
value = match.group(0) value = match.group(0)
if value == "or": if value == 'or':
yield Token(TokenType.OR, value, pos) yield Token(TokenType.OR, value, pos)
elif value == "and": elif value == 'and':
yield Token(TokenType.AND, value, pos) yield Token(TokenType.AND, value, pos)
elif value == "not": elif value == 'not':
yield Token(TokenType.NOT, value, pos) yield Token(TokenType.NOT, value, pos)
else: else:
yield Token(TokenType.IDENT, value, pos) yield Token(TokenType.IDENT, value, pos)
@ -103,9 +104,9 @@ class Scanner:
pos + 1, pos + 1,
f'unexpected character "{input[pos]}"', f'unexpected character "{input[pos]}"',
) )
yield Token(TokenType.EOF, "", pos) yield Token(TokenType.EOF, '', pos)
def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]: def accept(self, type: TokenType, *, reject: bool = False) -> Token | None:
if self.current.type is type: if self.current.type is type:
token = self.current token = self.current
if token.type is not TokenType.EOF: if token.type is not TokenType.EOF:
@ -118,8 +119,8 @@ class Scanner:
def reject(self, expected: Sequence[TokenType]) -> NoReturn: def reject(self, expected: Sequence[TokenType]) -> NoReturn:
raise ParseError( raise ParseError(
self.current.pos + 1, self.current.pos + 1,
"expected {}; got {}".format( 'expected {}; got {}'.format(
" OR ".join(type.value for type in expected), ' OR '.join(type.value for type in expected),
self.current.type.value, self.current.type.value,
), ),
) )
@ -128,7 +129,7 @@ class Scanner:
# True, False and None are legal match expression identifiers, # True, False and None are legal match expression identifiers,
# but illegal as Python identifiers. To fix this, this prefix # but illegal as Python identifiers. To fix this, this prefix
# is added to identifiers in the conversion to Python AST. # is added to identifiers in the conversion to Python AST.
IDENT_PREFIX = "$" IDENT_PREFIX = '$'
def expression(s: Scanner) -> ast.Expression: def expression(s: Scanner) -> ast.Expression:
@ -176,7 +177,7 @@ class MatcherAdapter(Mapping[str, bool]):
self.matcher = matcher self.matcher = matcher
def __getitem__(self, key: str) -> bool: def __getitem__(self, key: str) -> bool:
return self.matcher(key[len(IDENT_PREFIX) :]) return self.matcher(key[len(IDENT_PREFIX):])
def __iter__(self) -> Iterator[str]: def __iter__(self) -> Iterator[str]:
raise NotImplementedError() raise NotImplementedError()
@ -191,13 +192,13 @@ class Expression:
The expression can be evaluated against different matchers. The expression can be evaluated against different matchers.
""" """
__slots__ = ("code",) __slots__ = ('code',)
def __init__(self, code: types.CodeType) -> None: def __init__(self, code: types.CodeType) -> None:
self.code = code self.code = code
@classmethod @classmethod
def compile(self, input: str) -> "Expression": def compile(self, input: str) -> Expression:
"""Compile a match expression. """Compile a match expression.
:param input: The input expression - one line. :param input: The input expression - one line.
@ -205,8 +206,8 @@ class Expression:
astexpr = expression(Scanner(input)) astexpr = expression(Scanner(input))
code: types.CodeType = compile( code: types.CodeType = compile(
astexpr, astexpr,
filename="<pytest match expression>", filename='<pytest match expression>',
mode="eval", mode='eval',
) )
return Expression(code) return Expression(code)
@ -219,5 +220,5 @@ class Expression:
:returns: Whether the expression matches or not. :returns: Whether the expression matches or not.
""" """
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)) ret: bool = eval(self.code, {'__builtins__': {}}, MatcherAdapter(matcher))
return ret return ret

View file

@ -1,7 +1,10 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import collections.abc import collections.abc
import dataclasses import dataclasses
import inspect import inspect
import warnings
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Collection from typing import Collection
@ -21,37 +24,37 @@ from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
import warnings
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import NOTSET
from ..compat import NotSetType
from _pytest.config import Config from _pytest.config import Config
from _pytest.deprecated import check_ispytest from _pytest.deprecated import check_ispytest
from _pytest.deprecated import MARKED_FIXTURE from _pytest.deprecated import MARKED_FIXTURE
from _pytest.outcomes import fail from _pytest.outcomes import fail
from _pytest.warning_types import PytestUnknownMarkWarning from _pytest.warning_types import PytestUnknownMarkWarning
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import NOTSET
from ..compat import NotSetType
if TYPE_CHECKING: if TYPE_CHECKING:
from ..nodes import Node from ..nodes import Node
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark" EMPTY_PARAMETERSET_OPTION = 'empty_parameter_set_mark'
def istestfunc(func) -> bool: def istestfunc(func) -> bool:
return callable(func) and getattr(func, "__name__", "<lambda>") != "<lambda>" return callable(func) and getattr(func, '__name__', '<lambda>') != '<lambda>'
def get_empty_parameterset_mark( def get_empty_parameterset_mark(
config: Config, argnames: Sequence[str], func config: Config, argnames: Sequence[str], func,
) -> "MarkDecorator": ) -> MarkDecorator:
from ..nodes import Collector from ..nodes import Collector
fs, lineno = getfslineno(func) fs, lineno = getfslineno(func)
reason = "got empty parameter set %r, function %s at %s:%d" % ( reason = 'got empty parameter set %r, function %s at %s:%d' % (
argnames, argnames,
func.__name__, func.__name__,
fs, fs,
@ -59,15 +62,15 @@ def get_empty_parameterset_mark(
) )
requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION) requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
if requested_mark in ("", None, "skip"): if requested_mark in ('', None, 'skip'):
mark = MARK_GEN.skip(reason=reason) mark = MARK_GEN.skip(reason=reason)
elif requested_mark == "xfail": elif requested_mark == 'xfail':
mark = MARK_GEN.xfail(reason=reason, run=False) mark = MARK_GEN.xfail(reason=reason, run=False)
elif requested_mark == "fail_at_collect": elif requested_mark == 'fail_at_collect':
f_name = func.__name__ f_name = func.__name__
_, lineno = getfslineno(func) _, lineno = getfslineno(func)
raise Collector.CollectError( raise Collector.CollectError(
"Empty parameter set in '%s' at line %d" % (f_name, lineno + 1) "Empty parameter set in '%s' at line %d" % (f_name, lineno + 1),
) )
else: else:
raise LookupError(requested_mark) raise LookupError(requested_mark)
@ -75,17 +78,17 @@ def get_empty_parameterset_mark(
class ParameterSet(NamedTuple): class ParameterSet(NamedTuple):
values: Sequence[Union[object, NotSetType]] values: Sequence[object | NotSetType]
marks: Collection[Union["MarkDecorator", "Mark"]] marks: Collection[MarkDecorator | Mark]
id: Optional[str] id: str | None
@classmethod @classmethod
def param( def param(
cls, cls,
*values: object, *values: object,
marks: Union["MarkDecorator", Collection[Union["MarkDecorator", "Mark"]]] = (), marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: Optional[str] = None, id: str | None = None,
) -> "ParameterSet": ) -> ParameterSet:
if isinstance(marks, MarkDecorator): if isinstance(marks, MarkDecorator):
marks = (marks,) marks = (marks,)
else: else:
@ -93,16 +96,16 @@ class ParameterSet(NamedTuple):
if id is not None: if id is not None:
if not isinstance(id, str): if not isinstance(id, str):
raise TypeError(f"Expected id to be a string, got {type(id)}: {id!r}") raise TypeError(f'Expected id to be a string, got {type(id)}: {id!r}')
id = ascii_escaped(id) id = ascii_escaped(id)
return cls(values, marks, id) return cls(values, marks, id)
@classmethod @classmethod
def extract_from( def extract_from(
cls, cls,
parameterset: Union["ParameterSet", Sequence[object], object], parameterset: ParameterSet | Sequence[object] | object,
force_tuple: bool = False, force_tuple: bool = False,
) -> "ParameterSet": ) -> ParameterSet:
"""Extract from an object or objects. """Extract from an object or objects.
:param parameterset: :param parameterset:
@ -127,13 +130,13 @@ class ParameterSet(NamedTuple):
@staticmethod @staticmethod
def _parse_parametrize_args( def _parse_parametrize_args(
argnames: Union[str, Sequence[str]], argnames: str | Sequence[str],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
*args, *args,
**kwargs, **kwargs,
) -> Tuple[Sequence[str], bool]: ) -> tuple[Sequence[str], bool]:
if isinstance(argnames, str): if isinstance(argnames, str):
argnames = [x.strip() for x in argnames.split(",") if x.strip()] argnames = [x.strip() for x in argnames.split(',') if x.strip()]
force_tuple = len(argnames) == 1 force_tuple = len(argnames) == 1
else: else:
force_tuple = False force_tuple = False
@ -141,9 +144,9 @@ class ParameterSet(NamedTuple):
@staticmethod @staticmethod
def _parse_parametrize_parameters( def _parse_parametrize_parameters(
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
force_tuple: bool, force_tuple: bool,
) -> List["ParameterSet"]: ) -> list[ParameterSet]:
return [ return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
] ]
@ -151,12 +154,12 @@ class ParameterSet(NamedTuple):
@classmethod @classmethod
def _for_parametrize( def _for_parametrize(
cls, cls,
argnames: Union[str, Sequence[str]], argnames: str | Sequence[str],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
func, func,
config: Config, config: Config,
nodeid: str, nodeid: str,
) -> Tuple[Sequence[str], List["ParameterSet"]]: ) -> tuple[Sequence[str], list[ParameterSet]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues) argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple) parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues del argvalues
@ -167,9 +170,9 @@ class ParameterSet(NamedTuple):
if len(param.values) != len(argnames): if len(param.values) != len(argnames):
msg = ( msg = (
'{nodeid}: in "parametrize" the number of names ({names_len}):\n' '{nodeid}: in "parametrize" the number of names ({names_len}):\n'
" {names}\n" ' {names}\n'
"must be equal to the number of values ({values_len}):\n" 'must be equal to the number of values ({values_len}):\n'
" {values}" ' {values}'
) )
fail( fail(
msg.format( msg.format(
@ -186,7 +189,7 @@ class ParameterSet(NamedTuple):
# parameter set with NOTSET values, with the "empty parameter set" mark applied to it. # parameter set with NOTSET values, with the "empty parameter set" mark applied to it.
mark = get_empty_parameterset_mark(config, argnames, func) mark = get_empty_parameterset_mark(config, argnames, func)
parameters.append( parameters.append(
ParameterSet(values=(NOTSET,) * len(argnames), marks=[mark], id=None) ParameterSet(values=(NOTSET,) * len(argnames), marks=[mark], id=None),
) )
return argnames, parameters return argnames, parameters
@ -199,40 +202,40 @@ class Mark:
#: Name of the mark. #: Name of the mark.
name: str name: str
#: Positional arguments of the mark decorator. #: Positional arguments of the mark decorator.
args: Tuple[Any, ...] args: tuple[Any, ...]
#: Keyword arguments of the mark decorator. #: Keyword arguments of the mark decorator.
kwargs: Mapping[str, Any] kwargs: Mapping[str, Any]
#: Source Mark for ids with parametrize Marks. #: Source Mark for ids with parametrize Marks.
_param_ids_from: Optional["Mark"] = dataclasses.field(default=None, repr=False) _param_ids_from: Mark | None = dataclasses.field(default=None, repr=False)
#: Resolved/generated ids with parametrize Marks. #: Resolved/generated ids with parametrize Marks.
_param_ids_generated: Optional[Sequence[str]] = dataclasses.field( _param_ids_generated: Sequence[str] | None = dataclasses.field(
default=None, repr=False default=None, repr=False,
) )
def __init__( def __init__(
self, self,
name: str, name: str,
args: Tuple[Any, ...], args: tuple[Any, ...],
kwargs: Mapping[str, Any], kwargs: Mapping[str, Any],
param_ids_from: Optional["Mark"] = None, param_ids_from: Mark | None = None,
param_ids_generated: Optional[Sequence[str]] = None, param_ids_generated: Sequence[str] | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
""":meta private:""" """:meta private:"""
check_ispytest(_ispytest) check_ispytest(_ispytest)
# Weirdness to bypass frozen=True. # Weirdness to bypass frozen=True.
object.__setattr__(self, "name", name) object.__setattr__(self, 'name', name)
object.__setattr__(self, "args", args) object.__setattr__(self, 'args', args)
object.__setattr__(self, "kwargs", kwargs) object.__setattr__(self, 'kwargs', kwargs)
object.__setattr__(self, "_param_ids_from", param_ids_from) object.__setattr__(self, '_param_ids_from', param_ids_from)
object.__setattr__(self, "_param_ids_generated", param_ids_generated) object.__setattr__(self, '_param_ids_generated', param_ids_generated)
def _has_param_ids(self) -> bool: def _has_param_ids(self) -> bool:
return "ids" in self.kwargs or len(self.args) >= 4 return 'ids' in self.kwargs or len(self.args) >= 4
def combined_with(self, other: "Mark") -> "Mark": def combined_with(self, other: Mark) -> Mark:
"""Return a new Mark which is a combination of this """Return a new Mark which is a combination of this
Mark and another Mark. Mark and another Mark.
@ -244,8 +247,8 @@ class Mark:
assert self.name == other.name assert self.name == other.name
# Remember source of ids with parametrize Marks. # Remember source of ids with parametrize Marks.
param_ids_from: Optional[Mark] = None param_ids_from: Mark | None = None
if self.name == "parametrize": if self.name == 'parametrize':
if other._has_param_ids(): if other._has_param_ids():
param_ids_from = other param_ids_from = other
elif self._has_param_ids(): elif self._has_param_ids():
@ -263,7 +266,7 @@ class Mark:
# A generic parameter designating an object to which a Mark may # A generic parameter designating an object to which a Mark may
# be applied -- a test function (callable) or class. # be applied -- a test function (callable) or class.
# Note: a lambda is not allowed, but this can't be represented. # Note: a lambda is not allowed, but this can't be represented.
Markable = TypeVar("Markable", bound=Union[Callable[..., object], type]) Markable = TypeVar('Markable', bound=Union[Callable[..., object], type])
@dataclasses.dataclass @dataclasses.dataclass
@ -315,7 +318,7 @@ class MarkDecorator:
return self.mark.name return self.mark.name
@property @property
def args(self) -> Tuple[Any, ...]: def args(self) -> tuple[Any, ...]:
"""Alias for mark.args.""" """Alias for mark.args."""
return self.mark.args return self.mark.args
@ -329,7 +332,7 @@ class MarkDecorator:
""":meta private:""" """:meta private:"""
return self.name # for backward-compat (2.4.1 had this attr) return self.name # for backward-compat (2.4.1 had this attr)
def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator": def with_args(self, *args: object, **kwargs: object) -> MarkDecorator:
"""Return a MarkDecorator with extra arguments added. """Return a MarkDecorator with extra arguments added.
Unlike calling the MarkDecorator, with_args() can be used even Unlike calling the MarkDecorator, with_args() can be used even
@ -346,7 +349,7 @@ class MarkDecorator:
pass pass
@overload @overload
def __call__(self, *args: object, **kwargs: object) -> "MarkDecorator": def __call__(self, *args: object, **kwargs: object) -> MarkDecorator:
pass pass
def __call__(self, *args: object, **kwargs: object): def __call__(self, *args: object, **kwargs: object):
@ -361,10 +364,10 @@ class MarkDecorator:
def get_unpacked_marks( def get_unpacked_marks(
obj: Union[object, type], obj: object | type,
*, *,
consider_mro: bool = True, consider_mro: bool = True,
) -> List[Mark]: ) -> list[Mark]:
"""Obtain the unpacked marks that are stored on an object. """Obtain the unpacked marks that are stored on an object.
If obj is a class and consider_mro is true, return marks applied to If obj is a class and consider_mro is true, return marks applied to
@ -373,10 +376,10 @@ def get_unpacked_marks(
""" """
if isinstance(obj, type): if isinstance(obj, type):
if not consider_mro: if not consider_mro:
mark_lists = [obj.__dict__.get("pytestmark", [])] mark_lists = [obj.__dict__.get('pytestmark', [])]
else: else:
mark_lists = [ mark_lists = [
x.__dict__.get("pytestmark", []) for x in reversed(obj.__mro__) x.__dict__.get('pytestmark', []) for x in reversed(obj.__mro__)
] ]
mark_list = [] mark_list = []
for item in mark_lists: for item in mark_lists:
@ -385,7 +388,7 @@ def get_unpacked_marks(
else: else:
mark_list.append(item) mark_list.append(item)
else: else:
mark_attribute = getattr(obj, "pytestmark", []) mark_attribute = getattr(obj, 'pytestmark', [])
if isinstance(mark_attribute, list): if isinstance(mark_attribute, list):
mark_list = mark_attribute mark_list = mark_attribute
else: else:
@ -394,7 +397,7 @@ def get_unpacked_marks(
def normalize_mark_list( def normalize_mark_list(
mark_list: Iterable[Union[Mark, MarkDecorator]], mark_list: Iterable[Mark | MarkDecorator],
) -> Iterable[Mark]: ) -> Iterable[Mark]:
""" """
Normalize an iterable of Mark or MarkDecorator objects into a list of marks Normalize an iterable of Mark or MarkDecorator objects into a list of marks
@ -404,9 +407,9 @@ def normalize_mark_list(
:returns: A new list of the extracted Mark objects :returns: A new list of the extracted Mark objects
""" """
for mark in mark_list: for mark in mark_list:
mark_obj = getattr(mark, "mark", mark) mark_obj = getattr(mark, 'mark', mark)
if not isinstance(mark_obj, Mark): if not isinstance(mark_obj, Mark):
raise TypeError(f"got {mark_obj!r} instead of Mark") raise TypeError(f'got {mark_obj!r} instead of Mark')
yield mark_obj yield mark_obj
@ -438,14 +441,14 @@ if TYPE_CHECKING:
... ...
@overload @overload
def __call__(self, reason: str = ...) -> "MarkDecorator": def __call__(self, reason: str = ...) -> MarkDecorator:
... ...
class _SkipifMarkDecorator(MarkDecorator): class _SkipifMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override] def __call__( # type: ignore[override]
self, self,
condition: Union[str, bool] = ..., condition: str | bool = ...,
*conditions: Union[str, bool], *conditions: str | bool,
reason: str = ..., reason: str = ...,
) -> MarkDecorator: ) -> MarkDecorator:
... ...
@ -458,13 +461,13 @@ if TYPE_CHECKING:
@overload @overload
def __call__( def __call__(
self, self,
condition: Union[str, bool] = False, condition: str | bool = False,
*conditions: Union[str, bool], *conditions: str | bool,
reason: str = ..., reason: str = ...,
run: bool = ..., run: bool = ...,
raises: Union[ raises: (
None, Type[BaseException], Tuple[Type[BaseException], ...] None | type[BaseException] | tuple[type[BaseException], ...]
] = ..., ) = ...,
strict: bool = ..., strict: bool = ...,
) -> MarkDecorator: ) -> MarkDecorator:
... ...
@ -472,17 +475,15 @@ if TYPE_CHECKING:
class _ParametrizeMarkDecorator(MarkDecorator): class _ParametrizeMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override] def __call__( # type: ignore[override]
self, self,
argnames: Union[str, Sequence[str]], argnames: str | Sequence[str],
argvalues: Iterable[Union[ParameterSet, Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
*, *,
indirect: Union[bool, Sequence[str]] = ..., indirect: bool | Sequence[str] = ...,
ids: Optional[ ids: None | (
Union[ Iterable[None | str | float | int | bool] |
Iterable[Union[None, str, float, int, bool]], Callable[[Any], object | None]
Callable[[Any], Optional[object]], ) = ...,
] scope: _ScopeName | None = ...,
] = ...,
scope: Optional[_ScopeName] = ...,
) -> MarkDecorator: ) -> MarkDecorator:
... ...
@ -523,24 +524,24 @@ class MarkGenerator:
def __init__(self, *, _ispytest: bool = False) -> None: def __init__(self, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._config: Optional[Config] = None self._config: Config | None = None
self._markers: Set[str] = set() self._markers: set[str] = set()
def __getattr__(self, name: str) -> MarkDecorator: def __getattr__(self, name: str) -> MarkDecorator:
"""Generate a new :class:`MarkDecorator` with the given name.""" """Generate a new :class:`MarkDecorator` with the given name."""
if name[0] == "_": if name[0] == '_':
raise AttributeError("Marker name must NOT start with underscore") raise AttributeError('Marker name must NOT start with underscore')
if self._config is not None: if self._config is not None:
# We store a set of markers as a performance optimisation - if a mark # We store a set of markers as a performance optimisation - if a mark
# name is in the set we definitely know it, but a mark may be known and # name is in the set we definitely know it, but a mark may be known and
# not in the set. We therefore start by updating the set! # not in the set. We therefore start by updating the set!
if name not in self._markers: if name not in self._markers:
for line in self._config.getini("markers"): for line in self._config.getini('markers'):
# example lines: "skipif(condition): skip the given test if..." # example lines: "skipif(condition): skip the given test if..."
# or "hypothesis: tests which use Hypothesis", so to get the # or "hypothesis: tests which use Hypothesis", so to get the
# marker name we split on both `:` and `(`. # marker name we split on both `:` and `(`.
marker = line.split(":")[0].split("(")[0].strip() marker = line.split(':')[0].split('(')[0].strip()
self._markers.add(marker) self._markers.add(marker)
# If the name is not in the set of known marks after updating, # If the name is not in the set of known marks after updating,
@ -548,19 +549,19 @@ class MarkGenerator:
if name not in self._markers: if name not in self._markers:
if self._config.option.strict_markers or self._config.option.strict: if self._config.option.strict_markers or self._config.option.strict:
fail( fail(
f"{name!r} not found in `markers` configuration option", f'{name!r} not found in `markers` configuration option',
pytrace=False, pytrace=False,
) )
# Raise a specific error for common misspellings of "parametrize". # Raise a specific error for common misspellings of "parametrize".
if name in ["parameterize", "parametrise", "parameterise"]: if name in ['parameterize', 'parametrise', 'parameterise']:
__tracebackhide__ = True __tracebackhide__ = True
fail(f"Unknown '{name}' mark, did you mean 'parametrize'?") fail(f"Unknown '{name}' mark, did you mean 'parametrize'?")
warnings.warn( warnings.warn(
"Unknown pytest.mark.%s - is this a typo? You can register " 'Unknown pytest.mark.%s - is this a typo? You can register '
"custom marks to avoid this warning - for details, see " 'custom marks to avoid this warning - for details, see '
"https://docs.pytest.org/en/stable/how-to/mark.html" % name, 'https://docs.pytest.org/en/stable/how-to/mark.html' % name,
PytestUnknownMarkWarning, PytestUnknownMarkWarning,
2, 2,
) )
@ -573,9 +574,9 @@ MARK_GEN = MarkGenerator(_ispytest=True)
@final @final
class NodeKeywords(MutableMapping[str, Any]): class NodeKeywords(MutableMapping[str, Any]):
__slots__ = ("node", "parent", "_markers") __slots__ = ('node', 'parent', '_markers')
def __init__(self, node: "Node") -> None: def __init__(self, node: Node) -> None:
self.node = node self.node = node
self.parent = node.parent self.parent = node.parent
self._markers = {node.name: True} self._markers = {node.name: True}
@ -596,21 +597,21 @@ class NodeKeywords(MutableMapping[str, Any]):
def __contains__(self, key: object) -> bool: def __contains__(self, key: object) -> bool:
return ( return (
key in self._markers key in self._markers or
or self.parent is not None self.parent is not None and
and key in self.parent.keywords key in self.parent.keywords
) )
def update( # type: ignore[override] def update( # type: ignore[override]
self, self,
other: Union[Mapping[str, Any], Iterable[Tuple[str, Any]]] = (), other: Mapping[str, Any] | Iterable[tuple[str, Any]] = (),
**kwds: Any, **kwds: Any,
) -> None: ) -> None:
self._markers.update(other) self._markers.update(other)
self._markers.update(kwds) self._markers.update(kwds)
def __delitem__(self, key: str) -> None: def __delitem__(self, key: str) -> None:
raise ValueError("cannot delete key in keywords dict") raise ValueError('cannot delete key in keywords dict')
def __iter__(self) -> Iterator[str]: def __iter__(self) -> Iterator[str]:
# Doesn't need to be fast. # Doesn't need to be fast.
@ -626,4 +627,4 @@ class NodeKeywords(MutableMapping[str, Any]):
return sum(1 for keyword in self) return sum(1 for keyword in self)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<NodeKeywords for node {self.node}>" return f'<NodeKeywords for node {self.node}>'

View file

@ -1,9 +1,12 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Monkeypatching and mocking functionality.""" """Monkeypatching and mocking functionality."""
from contextlib import contextmanager from __future__ import annotations
import os import os
import re import re
import sys import sys
import warnings
from contextlib import contextmanager
from typing import Any from typing import Any
from typing import final from typing import final
from typing import Generator from typing import Generator
@ -15,21 +18,20 @@ from typing import overload
from typing import Tuple from typing import Tuple
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
import warnings
from _pytest.fixtures import fixture from _pytest.fixtures import fixture
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
RE_IMPORT_ERROR_NAME = re.compile(r"^No module named (.*)$") RE_IMPORT_ERROR_NAME = re.compile(r'^No module named (.*)$')
K = TypeVar("K") K = TypeVar('K')
V = TypeVar("V") V = TypeVar('V')
@fixture @fixture
def monkeypatch() -> Generator["MonkeyPatch", None, None]: def monkeypatch() -> Generator[MonkeyPatch, None, None]:
"""A convenient fixture for monkey-patching. """A convenient fixture for monkey-patching.
The fixture provides these methods to modify objects, dictionaries, or The fixture provides these methods to modify objects, dictionaries, or
@ -60,12 +62,12 @@ def monkeypatch() -> Generator["MonkeyPatch", None, None]:
def resolve(name: str) -> object: def resolve(name: str) -> object:
# Simplified from zope.dottedname. # Simplified from zope.dottedname.
parts = name.split(".") parts = name.split('.')
used = parts.pop(0) used = parts.pop(0)
found: object = __import__(used) found: object = __import__(used)
for part in parts: for part in parts:
used += "." + part used += '.' + part
try: try:
found = getattr(found, part) found = getattr(found, part)
except AttributeError: except AttributeError:
@ -81,7 +83,7 @@ def resolve(name: str) -> object:
if expected == used: if expected == used:
raise raise
else: else:
raise ImportError(f"import error in {used}: {ex}") from ex raise ImportError(f'import error in {used}: {ex}') from ex
found = annotated_getattr(found, part, used) found = annotated_getattr(found, part, used)
return found return found
@ -91,15 +93,15 @@ def annotated_getattr(obj: object, name: str, ann: str) -> object:
obj = getattr(obj, name) obj = getattr(obj, name)
except AttributeError as e: except AttributeError as e:
raise AttributeError( raise AttributeError(
f"{type(obj).__name__!r} object at {ann} has no attribute {name!r}" f'{type(obj).__name__!r} object at {ann} has no attribute {name!r}',
) from e ) from e
return obj return obj
def derive_importpath(import_path: str, raising: bool) -> Tuple[str, object]: def derive_importpath(import_path: str, raising: bool) -> tuple[str, object]:
if not isinstance(import_path, str) or "." not in import_path: if not isinstance(import_path, str) or '.' not in import_path:
raise TypeError(f"must be absolute import path string, not {import_path!r}") raise TypeError(f'must be absolute import path string, not {import_path!r}')
module, attr = import_path.rsplit(".", 1) module, attr = import_path.rsplit('.', 1)
target = resolve(module) target = resolve(module)
if raising: if raising:
annotated_getattr(target, attr, ann=module) annotated_getattr(target, attr, ann=module)
@ -108,7 +110,7 @@ def derive_importpath(import_path: str, raising: bool) -> Tuple[str, object]:
class Notset: class Notset:
def __repr__(self) -> str: def __repr__(self) -> str:
return "<notset>" return '<notset>'
notset = Notset() notset = Notset()
@ -129,14 +131,14 @@ class MonkeyPatch:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._setattr: List[Tuple[object, str, object]] = [] self._setattr: list[tuple[object, str, object]] = []
self._setitem: List[Tuple[Mapping[Any, Any], object, object]] = [] self._setitem: list[tuple[Mapping[Any, Any], object, object]] = []
self._cwd: Optional[str] = None self._cwd: str | None = None
self._savesyspath: Optional[List[str]] = None self._savesyspath: list[str] | None = None
@classmethod @classmethod
@contextmanager @contextmanager
def context(cls) -> Generator["MonkeyPatch", None, None]: def context(cls) -> Generator[MonkeyPatch, None, None]:
"""Context manager that returns a new :class:`MonkeyPatch` object """Context manager that returns a new :class:`MonkeyPatch` object
which undoes any patching done inside the ``with`` block upon exit. which undoes any patching done inside the ``with`` block upon exit.
@ -182,8 +184,8 @@ class MonkeyPatch:
def setattr( def setattr(
self, self,
target: Union[str, object], target: str | object,
name: Union[object, str], name: object | str,
value: object = notset, value: object = notset,
raising: bool = True, raising: bool = True,
) -> None: ) -> None:
@ -228,23 +230,23 @@ class MonkeyPatch:
if isinstance(value, Notset): if isinstance(value, Notset):
if not isinstance(target, str): if not isinstance(target, str):
raise TypeError( raise TypeError(
"use setattr(target, name, value) or " 'use setattr(target, name, value) or '
"setattr(target, value) with target being a dotted " 'setattr(target, value) with target being a dotted '
"import string" 'import string',
) )
value = name value = name
name, target = derive_importpath(target, raising) name, target = derive_importpath(target, raising)
else: else:
if not isinstance(name, str): if not isinstance(name, str):
raise TypeError( raise TypeError(
"use setattr(target, name, value) with name being a string or " 'use setattr(target, name, value) with name being a string or '
"setattr(target, value) with target being a dotted " 'setattr(target, value) with target being a dotted '
"import string" 'import string',
) )
oldval = getattr(target, name, notset) oldval = getattr(target, name, notset)
if raising and oldval is notset: if raising and oldval is notset:
raise AttributeError(f"{target!r} has no attribute {name!r}") raise AttributeError(f'{target!r} has no attribute {name!r}')
# avoid class descriptors like staticmethod/classmethod # avoid class descriptors like staticmethod/classmethod
if inspect.isclass(target): if inspect.isclass(target):
@ -254,8 +256,8 @@ class MonkeyPatch:
def delattr( def delattr(
self, self,
target: Union[object, str], target: object | str,
name: Union[str, Notset] = notset, name: str | Notset = notset,
raising: bool = True, raising: bool = True,
) -> None: ) -> None:
"""Delete attribute ``name`` from ``target``. """Delete attribute ``name`` from ``target``.
@ -273,9 +275,9 @@ class MonkeyPatch:
if isinstance(name, Notset): if isinstance(name, Notset):
if not isinstance(target, str): if not isinstance(target, str):
raise TypeError( raise TypeError(
"use delattr(target, name) or " 'use delattr(target, name) or '
"delattr(target) with target being a dotted " 'delattr(target) with target being a dotted '
"import string" 'import string',
) )
name, target = derive_importpath(target, raising) name, target = derive_importpath(target, raising)
@ -310,7 +312,7 @@ class MonkeyPatch:
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict # Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
del dic[name] # type: ignore[attr-defined] del dic[name] # type: ignore[attr-defined]
def setenv(self, name: str, value: str, prepend: Optional[str] = None) -> None: def setenv(self, name: str, value: str, prepend: str | None = None) -> None:
"""Set environment variable ``name`` to ``value``. """Set environment variable ``name`` to ``value``.
If ``prepend`` is a character, read the current environment variable If ``prepend`` is a character, read the current environment variable
@ -320,8 +322,8 @@ class MonkeyPatch:
if not isinstance(value, str): if not isinstance(value, str):
warnings.warn( # type: ignore[unreachable] warnings.warn( # type: ignore[unreachable]
PytestWarning( PytestWarning(
f"Value of environment variable {name} type should be str, but got " f'Value of environment variable {name} type should be str, but got '
f"{value!r} (type: {type(value).__name__}); converted to str implicitly" f'{value!r} (type: {type(value).__name__}); converted to str implicitly',
), ),
stacklevel=2, stacklevel=2,
) )
@ -347,7 +349,7 @@ class MonkeyPatch:
# https://github.com/pypa/setuptools/blob/d8b901bc/docs/pkg_resources.txt#L162-L171 # https://github.com/pypa/setuptools/blob/d8b901bc/docs/pkg_resources.txt#L162-L171
# this is only needed when pkg_resources was already loaded by the namespace package # this is only needed when pkg_resources was already loaded by the namespace package
if "pkg_resources" in sys.modules: if 'pkg_resources' in sys.modules:
from pkg_resources import fixup_namespace_packages from pkg_resources import fixup_namespace_packages
fixup_namespace_packages(str(path)) fixup_namespace_packages(str(path))
@ -363,7 +365,7 @@ class MonkeyPatch:
invalidate_caches() invalidate_caches()
def chdir(self, path: Union[str, "os.PathLike[str]"]) -> None: def chdir(self, path: str | os.PathLike[str]) -> None:
"""Change the current working directory to the specified path. """Change the current working directory to the specified path.
:param path: :param path:

View file

@ -1,9 +1,12 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import abc import abc
from functools import cached_property
from inspect import signature
import os import os
import pathlib import pathlib
import warnings
from functools import cached_property
from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Callable from typing import Callable
@ -21,11 +24,9 @@ from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
import warnings
import pluggy
import _pytest._code import _pytest._code
import pluggy
from _pytest._code import getfslineno from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr from _pytest._code.code import TerminalRepr
@ -53,18 +54,18 @@ if TYPE_CHECKING:
from _pytest.main import Session from _pytest.main import Session
SEP = "/" SEP = '/'
tracebackcutdir = Path(_pytest.__file__).parent tracebackcutdir = Path(_pytest.__file__).parent
_T = TypeVar("_T") _T = TypeVar('_T')
def _imply_path( def _imply_path(
node_type: Type["Node"], node_type: type[Node],
path: Optional[Path], path: Path | None,
fspath: Optional[LEGACY_PATH], fspath: LEGACY_PATH | None,
) -> Path: ) -> Path:
if fspath is not None: if fspath is not None:
warnings.warn( warnings.warn(
@ -82,7 +83,7 @@ def _imply_path(
return Path(fspath) return Path(fspath)
_NodeType = TypeVar("_NodeType", bound="Node") _NodeType = TypeVar('_NodeType', bound='Node')
class NodeMeta(abc.ABCMeta): class NodeMeta(abc.ABCMeta):
@ -102,28 +103,28 @@ class NodeMeta(abc.ABCMeta):
def __call__(cls, *k, **kw) -> NoReturn: def __call__(cls, *k, **kw) -> NoReturn:
msg = ( msg = (
"Direct construction of {name} has been deprecated, please use {name}.from_parent.\n" 'Direct construction of {name} has been deprecated, please use {name}.from_parent.\n'
"See " 'See '
"https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent" 'https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent'
" for more details." ' for more details.'
).format(name=f"{cls.__module__}.{cls.__name__}") ).format(name=f'{cls.__module__}.{cls.__name__}')
fail(msg, pytrace=False) fail(msg, pytrace=False)
def _create(cls: Type[_T], *k, **kw) -> _T: def _create(cls: type[_T], *k, **kw) -> _T:
try: try:
return super().__call__(*k, **kw) # type: ignore[no-any-return,misc] return super().__call__(*k, **kw) # type: ignore[no-any-return,misc]
except TypeError: except TypeError:
sig = signature(getattr(cls, "__init__")) sig = signature(getattr(cls, '__init__'))
known_kw = {k: v for k, v in kw.items() if k in sig.parameters} known_kw = {k: v for k, v in kw.items() if k in sig.parameters}
from .warning_types import PytestDeprecationWarning from .warning_types import PytestDeprecationWarning
warnings.warn( warnings.warn(
PytestDeprecationWarning( PytestDeprecationWarning(
f"{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n" f'{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n'
"See https://docs.pytest.org/en/stable/deprecations.html" 'See https://docs.pytest.org/en/stable/deprecations.html'
"#constructors-of-custom-pytest-node-subclasses-should-take-kwargs " '#constructors-of-custom-pytest-node-subclasses-should-take-kwargs '
"for more details." 'for more details.',
) ),
) )
return super().__call__(*k, **known_kw) # type: ignore[no-any-return,misc] return super().__call__(*k, **known_kw) # type: ignore[no-any-return,misc]
@ -147,25 +148,25 @@ class Node(abc.ABC, metaclass=NodeMeta):
# Use __slots__ to make attribute access faster. # Use __slots__ to make attribute access faster.
# Note that __dict__ is still available. # Note that __dict__ is still available.
__slots__ = ( __slots__ = (
"name", 'name',
"parent", 'parent',
"config", 'config',
"session", 'session',
"path", 'path',
"_nodeid", '_nodeid',
"_store", '_store',
"__dict__", '__dict__',
) )
def __init__( def __init__(
self, self,
name: str, name: str,
parent: "Optional[Node]" = None, parent: Optional[Node] = None,
config: Optional[Config] = None, config: Config | None = None,
session: "Optional[Session]" = None, session: Optional[Session] = None,
fspath: Optional[LEGACY_PATH] = None, fspath: LEGACY_PATH | None = None,
path: Optional[Path] = None, path: Path | None = None,
nodeid: Optional[str] = None, nodeid: str | None = None,
) -> None: ) -> None:
#: A unique name within the scope of the parent node. #: A unique name within the scope of the parent node.
self.name: str = name self.name: str = name
@ -178,7 +179,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.config: Config = config self.config: Config = config
else: else:
if not parent: if not parent:
raise TypeError("config or parent must be provided") raise TypeError('config or parent must be provided')
self.config = parent.config self.config = parent.config
if session: if session:
@ -186,11 +187,11 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.session: Session = session self.session: Session = session
else: else:
if not parent: if not parent:
raise TypeError("session or parent must be provided") raise TypeError('session or parent must be provided')
self.session = parent.session self.session = parent.session
if path is None and fspath is None: if path is None and fspath is None:
path = getattr(parent, "path", None) path = getattr(parent, 'path', None)
#: Filesystem path where this node was collected from (can be None). #: Filesystem path where this node was collected from (can be None).
self.path: pathlib.Path = _imply_path(type(self), path, fspath=fspath) self.path: pathlib.Path = _imply_path(type(self), path, fspath=fspath)
@ -199,18 +200,18 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.keywords: MutableMapping[str, Any] = NodeKeywords(self) self.keywords: MutableMapping[str, Any] = NodeKeywords(self)
#: The marker objects belonging to this node. #: The marker objects belonging to this node.
self.own_markers: List[Mark] = [] self.own_markers: list[Mark] = []
#: Allow adding of extra keywords to use for matching. #: Allow adding of extra keywords to use for matching.
self.extra_keyword_matches: Set[str] = set() self.extra_keyword_matches: set[str] = set()
if nodeid is not None: if nodeid is not None:
assert "::()" not in nodeid assert '::()' not in nodeid
self._nodeid = nodeid self._nodeid = nodeid
else: else:
if not self.parent: if not self.parent:
raise TypeError("nodeid or parent must be provided") raise TypeError('nodeid or parent must be provided')
self._nodeid = self.parent.nodeid + "::" + self.name self._nodeid = self.parent.nodeid + '::' + self.name
#: A place where plugins can store information on the node for their #: A place where plugins can store information on the node for their
#: own use. #: own use.
@ -219,7 +220,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
self._store = self.stash self._store = self.stash
@classmethod @classmethod
def from_parent(cls, parent: "Node", **kw) -> "Self": def from_parent(cls, parent: Node, **kw) -> Self:
"""Public constructor for Nodes. """Public constructor for Nodes.
This indirection got introduced in order to enable removing This indirection got introduced in order to enable removing
@ -230,10 +231,10 @@ class Node(abc.ABC, metaclass=NodeMeta):
:param parent: The parent node of this Node. :param parent: The parent node of this Node.
""" """
if "config" in kw: if 'config' in kw:
raise TypeError("config is not a valid argument for from_parent") raise TypeError('config is not a valid argument for from_parent')
if "session" in kw: if 'session' in kw:
raise TypeError("session is not a valid argument for from_parent") raise TypeError('session is not a valid argument for from_parent')
return cls._create(parent=parent, **kw) return cls._create(parent=parent, **kw)
@property @property
@ -242,7 +243,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return self.session.gethookproxy(self.path) return self.session.gethookproxy(self.path)
def __repr__(self) -> str: def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None)) return '<{} {}>'.format(self.__class__.__name__, getattr(self, 'name', None))
def warn(self, warning: Warning) -> None: def warn(self, warning: Warning) -> None:
"""Issue a warning for this Node. """Issue a warning for this Node.
@ -268,7 +269,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
# enforce type checks here to avoid getting a generic type error later otherwise. # enforce type checks here to avoid getting a generic type error later otherwise.
if not isinstance(warning, Warning): if not isinstance(warning, Warning):
raise ValueError( raise ValueError(
f"warning must be an instance of Warning or subclass, got {warning!r}" f'warning must be an instance of Warning or subclass, got {warning!r}',
) )
path, lineno = get_fslocation_from_item(self) path, lineno = get_fslocation_from_item(self)
assert lineno is not None assert lineno is not None
@ -295,22 +296,22 @@ class Node(abc.ABC, metaclass=NodeMeta):
def teardown(self) -> None: def teardown(self) -> None:
pass pass
def iter_parents(self) -> Iterator["Node"]: def iter_parents(self) -> Iterator[Node]:
"""Iterate over all parent collectors starting from and including self """Iterate over all parent collectors starting from and including self
up to the root of the collection tree. up to the root of the collection tree.
.. versionadded:: 8.1 .. versionadded:: 8.1
""" """
parent: Optional[Node] = self parent: Node | None = self
while parent is not None: while parent is not None:
yield parent yield parent
parent = parent.parent parent = parent.parent
def listchain(self) -> List["Node"]: def listchain(self) -> list[Node]:
"""Return a list of all parent collectors starting from the root of the """Return a list of all parent collectors starting from the root of the
collection tree down to and including self.""" collection tree down to and including self."""
chain = [] chain = []
item: Optional[Node] = self item: Node | None = self
while item is not None: while item is not None:
chain.append(item) chain.append(item)
item = item.parent item = item.parent
@ -318,7 +319,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return chain return chain
def add_marker( def add_marker(
self, marker: Union[str, MarkDecorator], append: bool = True self, marker: str | MarkDecorator, append: bool = True,
) -> None: ) -> None:
"""Dynamically add a marker object to the node. """Dynamically add a marker object to the node.
@ -334,14 +335,14 @@ class Node(abc.ABC, metaclass=NodeMeta):
elif isinstance(marker, str): elif isinstance(marker, str):
marker_ = getattr(MARK_GEN, marker) marker_ = getattr(MARK_GEN, marker)
else: else:
raise ValueError("is not a string or pytest.mark.* Marker") raise ValueError('is not a string or pytest.mark.* Marker')
self.keywords[marker_.name] = marker_ self.keywords[marker_.name] = marker_
if append: if append:
self.own_markers.append(marker_.mark) self.own_markers.append(marker_.mark)
else: else:
self.own_markers.insert(0, marker_.mark) self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]: def iter_markers(self, name: str | None = None) -> Iterator[Mark]:
"""Iterate over all markers of the node. """Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute. :param name: If given, filter the results by the name attribute.
@ -350,8 +351,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
return (x[1] for x in self.iter_markers_with_node(name=name)) return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node( def iter_markers_with_node(
self, name: Optional[str] = None self, name: str | None = None,
) -> Iterator[Tuple["Node", Mark]]: ) -> Iterator[tuple[Node, Mark]]:
"""Iterate over all markers of the node. """Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute. :param name: If given, filter the results by the name attribute.
@ -359,11 +360,11 @@ class Node(abc.ABC, metaclass=NodeMeta):
""" """
for node in self.iter_parents(): for node in self.iter_parents():
for mark in node.own_markers: for mark in node.own_markers:
if name is None or getattr(mark, "name", None) == name: if name is None or getattr(mark, 'name', None) == name:
yield node, mark yield node, mark
@overload @overload
def get_closest_marker(self, name: str) -> Optional[Mark]: def get_closest_marker(self, name: str) -> Mark | None:
... ...
@overload @overload
@ -371,8 +372,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
... ...
def get_closest_marker( def get_closest_marker(
self, name: str, default: Optional[Mark] = None self, name: str, default: Mark | None = None,
) -> Optional[Mark]: ) -> Mark | None:
"""Return the first marker matching the name, from closest (for """Return the first marker matching the name, from closest (for
example function) to farther level (for example module level). example function) to farther level (for example module level).
@ -381,14 +382,14 @@ class Node(abc.ABC, metaclass=NodeMeta):
""" """
return next(self.iter_markers(name=name), default) return next(self.iter_markers(name=name), default)
def listextrakeywords(self) -> Set[str]: def listextrakeywords(self) -> set[str]:
"""Return a set of all extra keywords in self and any parents.""" """Return a set of all extra keywords in self and any parents."""
extra_keywords: Set[str] = set() extra_keywords: set[str] = set()
for item in self.listchain(): for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches) extra_keywords.update(item.extra_keyword_matches)
return extra_keywords return extra_keywords
def listnames(self) -> List[str]: def listnames(self) -> list[str]:
return [x.name for x in self.listchain()] return [x.name for x in self.listchain()]
def addfinalizer(self, fin: Callable[[], object]) -> None: def addfinalizer(self, fin: Callable[[], object]) -> None:
@ -400,7 +401,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
""" """
self.session._setupstate.addfinalizer(fin, self) self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]: def getparent(self, cls: type[_NodeType]) -> _NodeType | None:
"""Get the closest parent node (including self) which is an instance of """Get the closest parent node (including self) which is an instance of
the given class. the given class.
@ -418,7 +419,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
def _repr_failure_py( def _repr_failure_py(
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None, style: Optional[_TracebackStyle] = None,
) -> TerminalRepr: ) -> TerminalRepr:
from _pytest.fixtures import FixtureLookupError from _pytest.fixtures import FixtureLookupError
@ -426,26 +427,26 @@ class Node(abc.ABC, metaclass=NodeMeta):
excinfo = ExceptionInfo.from_exception(excinfo.value.cause) excinfo = ExceptionInfo.from_exception(excinfo.value.cause)
if isinstance(excinfo.value, fail.Exception): if isinstance(excinfo.value, fail.Exception):
if not excinfo.value.pytrace: if not excinfo.value.pytrace:
style = "value" style = 'value'
if isinstance(excinfo.value, FixtureLookupError): if isinstance(excinfo.value, FixtureLookupError):
return excinfo.value.formatrepr() return excinfo.value.formatrepr()
tbfilter: Union[bool, Callable[[ExceptionInfo[BaseException]], Traceback]] tbfilter: bool | Callable[[ExceptionInfo[BaseException]], Traceback]
if self.config.getoption("fulltrace", False): if self.config.getoption('fulltrace', False):
style = "long" style = 'long'
tbfilter = False tbfilter = False
else: else:
tbfilter = self._traceback_filter tbfilter = self._traceback_filter
if style == "auto": if style == 'auto':
style = "long" style = 'long'
# XXX should excinfo.getrepr record all data and toterminal() process it? # XXX should excinfo.getrepr record all data and toterminal() process it?
if style is None: if style is None:
if self.config.getoption("tbstyle", "auto") == "short": if self.config.getoption('tbstyle', 'auto') == 'short':
style = "short" style = 'short'
else: else:
style = "long" style = 'long'
if self.config.getoption("verbose", 0) > 1: if self.config.getoption('verbose', 0) > 1:
truncate_locals = False truncate_locals = False
else: else:
truncate_locals = True truncate_locals = True
@ -464,7 +465,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return excinfo.getrepr( return excinfo.getrepr(
funcargs=True, funcargs=True,
abspath=abspath, abspath=abspath,
showlocals=self.config.getoption("showlocals", False), showlocals=self.config.getoption('showlocals', False),
style=style, style=style,
tbfilter=tbfilter, tbfilter=tbfilter,
truncate_locals=truncate_locals, truncate_locals=truncate_locals,
@ -473,8 +474,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
def repr_failure( def repr_failure(
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None, style: Optional[_TracebackStyle] = None,
) -> Union[str, TerminalRepr]: ) -> str | TerminalRepr:
"""Return a representation of a collection or test failure. """Return a representation of a collection or test failure.
.. seealso:: :ref:`non-python tests` .. seealso:: :ref:`non-python tests`
@ -484,7 +485,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return self._repr_failure_py(excinfo, style) return self._repr_failure_py(excinfo, style)
def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[int]]: def get_fslocation_from_item(node: Node) -> tuple[str | Path, int | None]:
"""Try to extract the actual location from a node, depending on available attributes: """Try to extract the actual location from a node, depending on available attributes:
* "location": a pair (path, lineno) * "location": a pair (path, lineno)
@ -494,13 +495,13 @@ def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[i
:rtype: A tuple of (str|Path, int) with filename and 0-based line number. :rtype: A tuple of (str|Path, int) with filename and 0-based line number.
""" """
# See Item.location. # See Item.location.
location: Optional[Tuple[str, Optional[int], str]] = getattr(node, "location", None) location: tuple[str, int | None, str] | None = getattr(node, 'location', None)
if location is not None: if location is not None:
return location[:2] return location[:2]
obj = getattr(node, "obj", None) obj = getattr(node, 'obj', None)
if obj is not None: if obj is not None:
return getfslineno(obj) return getfslineno(obj)
return getattr(node, "path", "unknown location"), -1 return getattr(node, 'path', 'unknown location'), -1
class Collector(Node, abc.ABC): class Collector(Node, abc.ABC):
@ -514,34 +515,34 @@ class Collector(Node, abc.ABC):
"""An error during collection, contains a custom message.""" """An error during collection, contains a custom message."""
@abc.abstractmethod @abc.abstractmethod
def collect(self) -> Iterable[Union["Item", "Collector"]]: def collect(self) -> Iterable[Item | Collector]:
"""Collect children (items and collectors) for this collector.""" """Collect children (items and collectors) for this collector."""
raise NotImplementedError("abstract") raise NotImplementedError('abstract')
# TODO: This omits the style= parameter which breaks Liskov Substitution. # TODO: This omits the style= parameter which breaks Liskov Substitution.
def repr_failure( # type: ignore[override] def repr_failure( # type: ignore[override]
self, excinfo: ExceptionInfo[BaseException] self, excinfo: ExceptionInfo[BaseException],
) -> Union[str, TerminalRepr]: ) -> str | TerminalRepr:
"""Return a representation of a collection failure. """Return a representation of a collection failure.
:param excinfo: Exception information for the failure. :param excinfo: Exception information for the failure.
""" """
if isinstance(excinfo.value, self.CollectError) and not self.config.getoption( if isinstance(excinfo.value, self.CollectError) and not self.config.getoption(
"fulltrace", False 'fulltrace', False,
): ):
exc = excinfo.value exc = excinfo.value
return str(exc.args[0]) return str(exc.args[0])
# Respect explicit tbstyle option, but default to "short" # Respect explicit tbstyle option, but default to "short"
# (_repr_failure_py uses "long" with "fulltrace" option always). # (_repr_failure_py uses "long" with "fulltrace" option always).
tbstyle = self.config.getoption("tbstyle", "auto") tbstyle = self.config.getoption('tbstyle', 'auto')
if tbstyle == "auto": if tbstyle == 'auto':
tbstyle = "short" tbstyle = 'short'
return self._repr_failure_py(excinfo, style=tbstyle) return self._repr_failure_py(excinfo, style=tbstyle)
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback: def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
if hasattr(self, "path"): if hasattr(self, 'path'):
traceback = excinfo.traceback traceback = excinfo.traceback
ntraceback = traceback.cut(path=self.path) ntraceback = traceback.cut(path=self.path)
if ntraceback == traceback: if ntraceback == traceback:
@ -550,11 +551,11 @@ class Collector(Node, abc.ABC):
return excinfo.traceback return excinfo.traceback
def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]: def _check_initialpaths_for_relpath(session: Session, path: Path) -> str | None:
for initial_path in session._initialpaths: for initial_path in session._initialpaths:
if commonpath(path, initial_path) == initial_path: if commonpath(path, initial_path) == initial_path:
rel = str(path.relative_to(initial_path)) rel = str(path.relative_to(initial_path))
return "" if rel == "." else rel return '' if rel == '.' else rel
return None return None
@ -563,14 +564,14 @@ class FSCollector(Collector, abc.ABC):
def __init__( def __init__(
self, self,
fspath: Optional[LEGACY_PATH] = None, fspath: LEGACY_PATH | None = None,
path_or_parent: Optional[Union[Path, Node]] = None, path_or_parent: Path | Node | None = None,
path: Optional[Path] = None, path: Path | None = None,
name: Optional[str] = None, name: str | None = None,
parent: Optional[Node] = None, parent: Node | None = None,
config: Optional[Config] = None, config: Config | None = None,
session: Optional["Session"] = None, session: Session | None = None,
nodeid: Optional[str] = None, nodeid: str | None = None,
) -> None: ) -> None:
if path_or_parent: if path_or_parent:
if isinstance(path_or_parent, Node): if isinstance(path_or_parent, Node):
@ -620,10 +621,10 @@ class FSCollector(Collector, abc.ABC):
cls, cls,
parent, parent,
*, *,
fspath: Optional[LEGACY_PATH] = None, fspath: LEGACY_PATH | None = None,
path: Optional[Path] = None, path: Path | None = None,
**kw, **kw,
) -> "Self": ) -> Self:
"""The public constructor.""" """The public constructor."""
return super().from_parent(parent=parent, fspath=fspath, path=path, **kw) return super().from_parent(parent=parent, fspath=fspath, path=path, **kw)
@ -665,9 +666,9 @@ class Item(Node, abc.ABC):
self, self,
name, name,
parent=None, parent=None,
config: Optional[Config] = None, config: Config | None = None,
session: Optional["Session"] = None, session: Session | None = None,
nodeid: Optional[str] = None, nodeid: str | None = None,
**kw, **kw,
) -> None: ) -> None:
# The first two arguments are intentionally passed positionally, # The first two arguments are intentionally passed positionally,
@ -682,11 +683,11 @@ class Item(Node, abc.ABC):
nodeid=nodeid, nodeid=nodeid,
**kw, **kw,
) )
self._report_sections: List[Tuple[str, str, str]] = [] self._report_sections: list[tuple[str, str, str]] = []
#: A list of tuples (name, value) that holds user defined properties #: A list of tuples (name, value) that holds user defined properties
#: for this test. #: for this test.
self.user_properties: List[Tuple[str, object]] = [] self.user_properties: list[tuple[str, object]] = []
self._check_item_and_collector_diamond_inheritance() self._check_item_and_collector_diamond_inheritance()
@ -701,21 +702,21 @@ class Item(Node, abc.ABC):
# for the same class more than once, which is not helpful. # for the same class more than once, which is not helpful.
# It is a hack, but was deemed acceptable in order to avoid # It is a hack, but was deemed acceptable in order to avoid
# flooding the user in the common case. # flooding the user in the common case.
attr_name = "_pytest_diamond_inheritance_warning_shown" attr_name = '_pytest_diamond_inheritance_warning_shown'
if getattr(cls, attr_name, False): if getattr(cls, attr_name, False):
return return
setattr(cls, attr_name, True) setattr(cls, attr_name, True)
problems = ", ".join( problems = ', '.join(
base.__name__ for base in cls.__bases__ if issubclass(base, Collector) base.__name__ for base in cls.__bases__ if issubclass(base, Collector)
) )
if problems: if problems:
warnings.warn( warnings.warn(
f"{cls.__name__} is an Item subclass and should not be a collector, " f'{cls.__name__} is an Item subclass and should not be a collector, '
f"however its bases {problems} are collectors.\n" f'however its bases {problems} are collectors.\n'
"Please split the Collectors and the Item into separate node types.\n" 'Please split the Collectors and the Item into separate node types.\n'
"Pytest Doc example: https://docs.pytest.org/en/latest/example/nonpython.html\n" 'Pytest Doc example: https://docs.pytest.org/en/latest/example/nonpython.html\n'
"example pull request on a plugin: https://github.com/asmeurer/pytest-flakes/pull/40/", 'example pull request on a plugin: https://github.com/asmeurer/pytest-flakes/pull/40/',
PytestWarning, PytestWarning,
) )
@ -727,7 +728,7 @@ class Item(Node, abc.ABC):
.. seealso:: :ref:`non-python tests` .. seealso:: :ref:`non-python tests`
""" """
raise NotImplementedError("runtest must be implemented by Item subclass") raise NotImplementedError('runtest must be implemented by Item subclass')
def add_report_section(self, when: str, key: str, content: str) -> None: def add_report_section(self, when: str, key: str, content: str) -> None:
"""Add a new report section, similar to what's done internally to add """Add a new report section, similar to what's done internally to add
@ -746,7 +747,7 @@ class Item(Node, abc.ABC):
if content: if content:
self._report_sections.append((when, key, content)) self._report_sections.append((when, key, content))
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]: def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
"""Get location information for this item for test reports. """Get location information for this item for test reports.
Returns a tuple with three elements: Returns a tuple with three elements:
@ -757,10 +758,10 @@ class Item(Node, abc.ABC):
.. seealso:: :ref:`non-python tests` .. seealso:: :ref:`non-python tests`
""" """
return self.path, None, "" return self.path, None, ''
@cached_property @cached_property
def location(self) -> Tuple[str, Optional[int], str]: def location(self) -> tuple[str, int | None, str]:
""" """
Returns a tuple of ``(relfspath, lineno, testname)`` for this item Returns a tuple of ``(relfspath, lineno, testname)`` for this item
where ``relfspath`` is file path relative to ``config.rootpath`` where ``relfspath`` is file path relative to ``config.rootpath``

View file

@ -1,5 +1,6 @@
"""Exception classes and constants handling test outcomes as well as """Exception classes and constants handling test outcomes as well as
functions creating them.""" functions creating them."""
from __future__ import annotations
import sys import sys
from typing import Any from typing import Any
@ -16,11 +17,11 @@ class OutcomeException(BaseException):
"""OutcomeException and its subclass instances indicate and contain info """OutcomeException and its subclass instances indicate and contain info
about test and collection outcomes.""" about test and collection outcomes."""
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None: def __init__(self, msg: str | None = None, pytrace: bool = True) -> None:
if msg is not None and not isinstance(msg, str): if msg is not None and not isinstance(msg, str):
error_msg = ( # type: ignore[unreachable] error_msg = ( # type: ignore[unreachable]
"{} expected string as 'msg' parameter, got '{}' instead.\n" "{} expected string as 'msg' parameter, got '{}' instead.\n"
"Perhaps you meant to use a mark?" 'Perhaps you meant to use a mark?'
) )
raise TypeError(error_msg.format(type(self).__name__, type(msg).__name__)) raise TypeError(error_msg.format(type(self).__name__, type(msg).__name__))
super().__init__(msg) super().__init__(msg)
@ -30,7 +31,7 @@ class OutcomeException(BaseException):
def __repr__(self) -> str: def __repr__(self) -> str:
if self.msg is not None: if self.msg is not None:
return self.msg return self.msg
return f"<{self.__class__.__name__} instance>" return f'<{self.__class__.__name__} instance>'
__str__ = __repr__ __str__ = __repr__
@ -41,11 +42,11 @@ TEST_OUTCOME = (OutcomeException, Exception)
class Skipped(OutcomeException): class Skipped(OutcomeException):
# XXX hackish: on 3k we fake to live in the builtins # XXX hackish: on 3k we fake to live in the builtins
# in order to have Skipped exception printing shorter/nicer # in order to have Skipped exception printing shorter/nicer
__module__ = "builtins" __module__ = 'builtins'
def __init__( def __init__(
self, self,
msg: Optional[str] = None, msg: str | None = None,
pytrace: bool = True, pytrace: bool = True,
allow_module_level: bool = False, allow_module_level: bool = False,
*, *,
@ -61,14 +62,14 @@ class Skipped(OutcomeException):
class Failed(OutcomeException): class Failed(OutcomeException):
"""Raised from an explicit call to pytest.fail().""" """Raised from an explicit call to pytest.fail()."""
__module__ = "builtins" __module__ = 'builtins'
class Exit(Exception): class Exit(Exception):
"""Raised for immediate program exits (no tracebacks/summaries).""" """Raised for immediate program exits (no tracebacks/summaries)."""
def __init__( def __init__(
self, msg: str = "unknown reason", returncode: Optional[int] = None self, msg: str = 'unknown reason', returncode: int | None = None,
) -> None: ) -> None:
self.msg = msg self.msg = msg
self.returncode = returncode self.returncode = returncode
@ -78,8 +79,8 @@ class Exit(Exception):
# Elaborate hack to work around https://github.com/python/mypy/issues/2087. # Elaborate hack to work around https://github.com/python/mypy/issues/2087.
# Ideally would just be `exit.Exception = Exit` etc. # Ideally would just be `exit.Exception = Exit` etc.
_F = TypeVar("_F", bound=Callable[..., object]) _F = TypeVar('_F', bound=Callable[..., object])
_ET = TypeVar("_ET", bound=Type[BaseException]) _ET = TypeVar('_ET', bound=Type[BaseException])
class _WithException(Protocol[_F, _ET]): class _WithException(Protocol[_F, _ET]):
@ -101,8 +102,8 @@ def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _E
@_with_exception(Exit) @_with_exception(Exit)
def exit( def exit(
reason: str = "", reason: str = '',
returncode: Optional[int] = None, returncode: int | None = None,
) -> NoReturn: ) -> NoReturn:
"""Exit testing process. """Exit testing process.
@ -119,7 +120,7 @@ def exit(
@_with_exception(Skipped) @_with_exception(Skipped)
def skip( def skip(
reason: str = "", reason: str = '',
*, *,
allow_module_level: bool = False, allow_module_level: bool = False,
) -> NoReturn: ) -> NoReturn:
@ -152,7 +153,7 @@ def skip(
@_with_exception(Failed) @_with_exception(Failed)
def fail(reason: str = "", pytrace: bool = True) -> NoReturn: def fail(reason: str = '', pytrace: bool = True) -> NoReturn:
"""Explicitly fail an executing test with the given message. """Explicitly fail an executing test with the given message.
:param reason: :param reason:
@ -171,7 +172,7 @@ class XFailed(Failed):
@_with_exception(XFailed) @_with_exception(XFailed)
def xfail(reason: str = "") -> NoReturn: def xfail(reason: str = '') -> NoReturn:
"""Imperatively xfail an executing test or setup function with the given reason. """Imperatively xfail an executing test or setup function with the given reason.
This function should be called only during testing (setup, call or teardown). This function should be called only during testing (setup, call or teardown).
@ -192,7 +193,7 @@ def xfail(reason: str = "") -> NoReturn:
def importorskip( def importorskip(
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None modname: str, minversion: str | None = None, reason: str | None = None,
) -> Any: ) -> Any:
"""Import and return the requested module ``modname``, or skip the """Import and return the requested module ``modname``, or skip the
current test if the module cannot be imported. current test if the module cannot be imported.
@ -216,30 +217,30 @@ def importorskip(
import warnings import warnings
__tracebackhide__ = True __tracebackhide__ = True
compile(modname, "", "eval") # to catch syntaxerrors compile(modname, '', 'eval') # to catch syntaxerrors
with warnings.catch_warnings(): with warnings.catch_warnings():
# Make sure to ignore ImportWarnings that might happen because # Make sure to ignore ImportWarnings that might happen because
# of existing directories with the same name we're trying to # of existing directories with the same name we're trying to
# import but without a __init__.py file. # import but without a __init__.py file.
warnings.simplefilter("ignore") warnings.simplefilter('ignore')
try: try:
__import__(modname) __import__(modname)
except ImportError as exc: except ImportError as exc:
if reason is None: if reason is None:
reason = f"could not import {modname!r}: {exc}" reason = f'could not import {modname!r}: {exc}'
raise Skipped(reason, allow_module_level=True) from None raise Skipped(reason, allow_module_level=True) from None
mod = sys.modules[modname] mod = sys.modules[modname]
if minversion is None: if minversion is None:
return mod return mod
verattr = getattr(mod, "__version__", None) verattr = getattr(mod, '__version__', None)
if minversion is not None: if minversion is not None:
# Imported lazily to improve start-up time. # Imported lazily to improve start-up time.
from packaging.version import Version from packaging.version import Version
if verattr is None or Version(verattr) < Version(minversion): if verattr is None or Version(verattr) < Version(minversion):
raise Skipped( raise Skipped(
f"module {modname!r} has __version__ {verattr!r}, required is: {minversion!r}", f'module {modname!r} has __version__ {verattr!r}, required is: {minversion!r}',
allow_module_level=True, allow_module_level=True,
) )
return mod return mod

View file

@ -1,50 +1,52 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Submit failure or test session information to a pastebin service.""" """Submit failure or test session information to a pastebin service."""
from io import StringIO from __future__ import annotations
import tempfile import tempfile
from io import StringIO
from typing import IO from typing import IO
from typing import Union from typing import Union
import pytest
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import create_terminal_writer from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
import pytest
pastebinfile_key = StashKey[IO[bytes]]() pastebinfile_key = StashKey[IO[bytes]]()
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting") group = parser.getgroup('terminal reporting')
group._addoption( group._addoption(
"--pastebin", '--pastebin',
metavar="mode", metavar='mode',
action="store", action='store',
dest="pastebin", dest='pastebin',
default=None, default=None,
choices=["failed", "all"], choices=['failed', 'all'],
help="Send failed|all info to bpaste.net pastebin service", help='Send failed|all info to bpaste.net pastebin service',
) )
@pytest.hookimpl(trylast=True) @pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
if config.option.pastebin == "all": if config.option.pastebin == 'all':
tr = config.pluginmanager.getplugin("terminalreporter") tr = config.pluginmanager.getplugin('terminalreporter')
# If no terminal reporter plugin is present, nothing we can do here; # If no terminal reporter plugin is present, nothing we can do here;
# this can happen when this function executes in a worker node # this can happen when this function executes in a worker node
# when using pytest-xdist, for example. # when using pytest-xdist, for example.
if tr is not None: if tr is not None:
# pastebin file will be UTF-8 encoded binary file. # pastebin file will be UTF-8 encoded binary file.
config.stash[pastebinfile_key] = tempfile.TemporaryFile("w+b") config.stash[pastebinfile_key] = tempfile.TemporaryFile('w+b')
oldwrite = tr._tw.write oldwrite = tr._tw.write
def tee_write(s, **kwargs): def tee_write(s, **kwargs):
oldwrite(s, **kwargs) oldwrite(s, **kwargs)
if isinstance(s, str): if isinstance(s, str):
s = s.encode("utf-8") s = s.encode('utf-8')
config.stash[pastebinfile_key].write(s) config.stash[pastebinfile_key].write(s)
tr._tw.write = tee_write tr._tw.write = tee_write
@ -59,15 +61,15 @@ def pytest_unconfigure(config: Config) -> None:
pastebinfile.close() pastebinfile.close()
del config.stash[pastebinfile_key] del config.stash[pastebinfile_key]
# Undo our patching in the terminal reporter. # Undo our patching in the terminal reporter.
tr = config.pluginmanager.getplugin("terminalreporter") tr = config.pluginmanager.getplugin('terminalreporter')
del tr._tw.__dict__["write"] del tr._tw.__dict__['write']
# Write summary. # Write summary.
tr.write_sep("=", "Sending information to Paste Service") tr.write_sep('=', 'Sending information to Paste Service')
pastebinurl = create_new_paste(sessionlog) pastebinurl = create_new_paste(sessionlog)
tr.write_line("pastebin session-log: %s\n" % pastebinurl) tr.write_line('pastebin session-log: %s\n' % pastebinurl)
def create_new_paste(contents: Union[str, bytes]) -> str: def create_new_paste(contents: str | bytes) -> str:
"""Create a new paste using the bpaste.net service. """Create a new paste using the bpaste.net service.
:contents: Paste contents string. :contents: Paste contents string.
@ -77,27 +79,27 @@ def create_new_paste(contents: Union[str, bytes]) -> str:
from urllib.parse import urlencode from urllib.parse import urlencode
from urllib.request import urlopen from urllib.request import urlopen
params = {"code": contents, "lexer": "text", "expiry": "1week"} params = {'code': contents, 'lexer': 'text', 'expiry': '1week'}
url = "https://bpa.st" url = 'https://bpa.st'
try: try:
response: str = ( response: str = (
urlopen(url, data=urlencode(params).encode("ascii")).read().decode("utf-8") urlopen(url, data=urlencode(params).encode('ascii')).read().decode('utf-8')
) )
except OSError as exc_info: # urllib errors except OSError as exc_info: # urllib errors
return "bad response: %s" % exc_info return 'bad response: %s' % exc_info
m = re.search(r'href="/raw/(\w+)"', response) m = re.search(r'href="/raw/(\w+)"', response)
if m: if m:
return f"{url}/show/{m.group(1)}" return f'{url}/show/{m.group(1)}'
else: else:
return "bad response: invalid format ('" + response + "')" return "bad response: invalid format ('" + response + "')"
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None: def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
if terminalreporter.config.option.pastebin != "failed": if terminalreporter.config.option.pastebin != 'failed':
return return
if "failed" in terminalreporter.stats: if 'failed' in terminalreporter.stats:
terminalreporter.write_sep("=", "Sending information to Paste Service") terminalreporter.write_sep('=', 'Sending information to Paste Service')
for rep in terminalreporter.stats["failed"]: for rep in terminalreporter.stats['failed']:
try: try:
msg = rep.longrepr.reprtraceback.reprentries[-1].reprfileloc msg = rep.longrepr.reprtraceback.reprentries[-1].reprfileloc
except AttributeError: except AttributeError:
@ -108,4 +110,4 @@ def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
s = file.getvalue() s = file.getvalue()
assert len(s) assert len(s)
pastebinurl = create_new_paste(s) pastebinurl = create_new_paste(s)
terminalreporter.write_line(f"{msg} --> {pastebinurl}") terminalreporter.write_line(f'{msg} --> {pastebinurl}')

View file

@ -1,16 +1,23 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import atexit import atexit
import contextlib import contextlib
import fnmatch
import importlib.util
import itertools
import os
import shutil
import sys
import types
import uuid
import warnings
from enum import Enum from enum import Enum
from errno import EBADF from errno import EBADF
from errno import ELOOP from errno import ELOOP
from errno import ENOENT from errno import ENOENT
from errno import ENOTDIR from errno import ENOTDIR
import fnmatch
from functools import partial from functools import partial
import importlib.util
import itertools
import os
from os.path import expanduser from os.path import expanduser
from os.path import expandvars from os.path import expandvars
from os.path import isabs from os.path import isabs
@ -18,9 +25,6 @@ from os.path import sep
from pathlib import Path from pathlib import Path
from pathlib import PurePath from pathlib import PurePath
from posixpath import sep as posix_sep from posixpath import sep as posix_sep
import shutil
import sys
import types
from types import ModuleType from types import ModuleType
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
@ -33,8 +37,6 @@ from typing import Tuple
from typing import Type from typing import Type
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
import uuid
import warnings
from _pytest.compat import assert_never from _pytest.compat import assert_never
from _pytest.outcomes import skip from _pytest.outcomes import skip
@ -44,7 +46,7 @@ from _pytest.warning_types import PytestWarning
LOCK_TIMEOUT = 60 * 60 * 24 * 3 LOCK_TIMEOUT = 60 * 60 * 24 * 3
_AnyPurePath = TypeVar("_AnyPurePath", bound=PurePath) _AnyPurePath = TypeVar('_AnyPurePath', bound=PurePath)
# The following function, variables and comments were # The following function, variables and comments were
# copied from cpython 3.9 Lib/pathlib.py file. # copied from cpython 3.9 Lib/pathlib.py file.
@ -60,22 +62,22 @@ _IGNORED_WINERRORS = (
def _ignore_error(exception): def _ignore_error(exception):
return ( return (
getattr(exception, "errno", None) in _IGNORED_ERRORS getattr(exception, 'errno', None) in _IGNORED_ERRORS or
or getattr(exception, "winerror", None) in _IGNORED_WINERRORS getattr(exception, 'winerror', None) in _IGNORED_WINERRORS
) )
def get_lock_path(path: _AnyPurePath) -> _AnyPurePath: def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
return path.joinpath(".lock") return path.joinpath('.lock')
def on_rm_rf_error( def on_rm_rf_error(
func, func,
path: str, path: str,
excinfo: Union[ excinfo: (
BaseException, BaseException |
Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]], tuple[type[BaseException], BaseException, types.TracebackType | None]
], ),
*, *,
start_path: Path, start_path: Path,
) -> bool: ) -> bool:
@ -95,7 +97,7 @@ def on_rm_rf_error(
if not isinstance(exc, PermissionError): if not isinstance(exc, PermissionError):
warnings.warn( warnings.warn(
PytestWarning(f"(rm_rf) error removing {path}\n{type(exc)}: {exc}") PytestWarning(f'(rm_rf) error removing {path}\n{type(exc)}: {exc}'),
) )
return False return False
@ -103,8 +105,8 @@ def on_rm_rf_error(
if func not in (os.open,): if func not in (os.open,):
warnings.warn( warnings.warn(
PytestWarning( PytestWarning(
f"(rm_rf) unknown function {func} when removing {path}:\n{type(exc)}: {exc}" f'(rm_rf) unknown function {func} when removing {path}:\n{type(exc)}: {exc}',
) ),
) )
return False return False
@ -142,7 +144,7 @@ def ensure_extended_length_path(path: Path) -> Path:
On Windows, this function returns the extended-length absolute version of path. On Windows, this function returns the extended-length absolute version of path.
On other platforms it returns path unchanged. On other platforms it returns path unchanged.
""" """
if sys.platform.startswith("win32"): if sys.platform.startswith('win32'):
path = path.resolve() path = path.resolve()
path = Path(get_extended_length_path_str(str(path))) path = Path(get_extended_length_path_str(str(path)))
return path return path
@ -150,12 +152,12 @@ def ensure_extended_length_path(path: Path) -> Path:
def get_extended_length_path_str(path: str) -> str: def get_extended_length_path_str(path: str) -> str:
"""Convert a path to a Windows extended length path.""" """Convert a path to a Windows extended length path."""
long_path_prefix = "\\\\?\\" long_path_prefix = '\\\\?\\'
unc_long_path_prefix = "\\\\?\\UNC\\" unc_long_path_prefix = '\\\\?\\UNC\\'
if path.startswith((long_path_prefix, unc_long_path_prefix)): if path.startswith((long_path_prefix, unc_long_path_prefix)):
return path return path
# UNC # UNC
if path.startswith("\\\\"): if path.startswith('\\\\'):
return unc_long_path_prefix + path[2:] return unc_long_path_prefix + path[2:]
return long_path_prefix + path return long_path_prefix + path
@ -171,7 +173,7 @@ def rm_rf(path: Path) -> None:
shutil.rmtree(str(path), onerror=onerror) shutil.rmtree(str(path), onerror=onerror)
def find_prefixed(root: Path, prefix: str) -> Iterator["os.DirEntry[str]"]: def find_prefixed(root: Path, prefix: str) -> Iterator[os.DirEntry[str]]:
"""Find all elements in root that begin with the prefix, case insensitive.""" """Find all elements in root that begin with the prefix, case insensitive."""
l_prefix = prefix.lower() l_prefix = prefix.lower()
for x in os.scandir(root): for x in os.scandir(root):
@ -179,7 +181,7 @@ def find_prefixed(root: Path, prefix: str) -> Iterator["os.DirEntry[str]"]:
yield x yield x
def extract_suffixes(iter: Iterable["os.DirEntry[str]"], prefix: str) -> Iterator[str]: def extract_suffixes(iter: Iterable[os.DirEntry[str]], prefix: str) -> Iterator[str]:
"""Return the parts of the paths following the prefix. """Return the parts of the paths following the prefix.
:param iter: Iterator over path names. :param iter: Iterator over path names.
@ -204,7 +206,7 @@ def parse_num(maybe_num) -> int:
def _force_symlink( def _force_symlink(
root: Path, target: Union[str, PurePath], link_to: Union[str, Path] root: Path, target: str | PurePath, link_to: str | Path,
) -> None: ) -> None:
"""Helper to create the current symlink. """Helper to create the current symlink.
@ -231,18 +233,18 @@ def make_numbered_dir(root: Path, prefix: str, mode: int = 0o700) -> Path:
# try up to 10 times to create the folder # try up to 10 times to create the folder
max_existing = max(map(parse_num, find_suffixes(root, prefix)), default=-1) max_existing = max(map(parse_num, find_suffixes(root, prefix)), default=-1)
new_number = max_existing + 1 new_number = max_existing + 1
new_path = root.joinpath(f"{prefix}{new_number}") new_path = root.joinpath(f'{prefix}{new_number}')
try: try:
new_path.mkdir(mode=mode) new_path.mkdir(mode=mode)
except Exception: except Exception:
pass pass
else: else:
_force_symlink(root, prefix + "current", new_path) _force_symlink(root, prefix + 'current', new_path)
return new_path return new_path
else: else:
raise OSError( raise OSError(
"could not create numbered dir with prefix " 'could not create numbered dir with prefix '
f"{prefix} in {root} after 10 tries" f'{prefix} in {root} after 10 tries',
) )
@ -252,14 +254,14 @@ def create_cleanup_lock(p: Path) -> Path:
try: try:
fd = os.open(str(lock_path), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) fd = os.open(str(lock_path), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
except FileExistsError as e: except FileExistsError as e:
raise OSError(f"cannot create lockfile in {p}") from e raise OSError(f'cannot create lockfile in {p}') from e
else: else:
pid = os.getpid() pid = os.getpid()
spid = str(pid).encode() spid = str(pid).encode()
os.write(fd, spid) os.write(fd, spid)
os.close(fd) os.close(fd)
if not lock_path.is_file(): if not lock_path.is_file():
raise OSError("lock path got renamed after successful creation") raise OSError('lock path got renamed after successful creation')
return lock_path return lock_path
@ -289,7 +291,7 @@ def maybe_delete_a_numbered_dir(path: Path) -> None:
lock_path = create_cleanup_lock(path) lock_path = create_cleanup_lock(path)
parent = path.parent parent = path.parent
garbage = parent.joinpath(f"garbage-{uuid.uuid4()}") garbage = parent.joinpath(f'garbage-{uuid.uuid4()}')
path.rename(garbage) path.rename(garbage)
rm_rf(garbage) rm_rf(garbage)
except OSError: except OSError:
@ -362,14 +364,14 @@ def cleanup_dead_symlinks(root: Path):
def cleanup_numbered_dir( def cleanup_numbered_dir(
root: Path, prefix: str, keep: int, consider_lock_dead_if_created_before: float root: Path, prefix: str, keep: int, consider_lock_dead_if_created_before: float,
) -> None: ) -> None:
"""Cleanup for lock driven numbered directories.""" """Cleanup for lock driven numbered directories."""
if not root.exists(): if not root.exists():
return return
for path in cleanup_candidates(root, prefix, keep): for path in cleanup_candidates(root, prefix, keep):
try_cleanup(path, consider_lock_dead_if_created_before) try_cleanup(path, consider_lock_dead_if_created_before)
for path in root.glob("garbage-*"): for path in root.glob('garbage-*'):
try_cleanup(path, consider_lock_dead_if_created_before) try_cleanup(path, consider_lock_dead_if_created_before)
cleanup_dead_symlinks(root) cleanup_dead_symlinks(root)
@ -417,7 +419,7 @@ def resolve_from_str(input: str, rootpath: Path) -> Path:
return rootpath.joinpath(input) return rootpath.joinpath(input)
def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool: def fnmatch_ex(pattern: str, path: str | os.PathLike[str]) -> bool:
"""A port of FNMatcher from py.path.common which works with PurePath() instances. """A port of FNMatcher from py.path.common which works with PurePath() instances.
The difference between this algorithm and PurePath.match() is that the The difference between this algorithm and PurePath.match() is that the
@ -436,7 +438,7 @@ def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool:
* https://bugs.python.org/issue34731 * https://bugs.python.org/issue34731
""" """
path = PurePath(path) path = PurePath(path)
iswin32 = sys.platform.startswith("win") iswin32 = sys.platform.startswith('win')
if iswin32 and sep not in pattern and posix_sep in pattern: if iswin32 and sep not in pattern and posix_sep in pattern:
# Running on Windows, the pattern has no Windows path separators, # Running on Windows, the pattern has no Windows path separators,
@ -449,11 +451,11 @@ def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool:
else: else:
name = str(path) name = str(path)
if path.is_absolute() and not os.path.isabs(pattern): if path.is_absolute() and not os.path.isabs(pattern):
pattern = f"*{os.sep}{pattern}" pattern = f'*{os.sep}{pattern}'
return fnmatch.fnmatch(name, pattern) return fnmatch.fnmatch(name, pattern)
def parts(s: str) -> Set[str]: def parts(s: str) -> set[str]:
parts = s.split(sep) parts = s.split(sep)
return {sep.join(parts[: i + 1]) or sep for i in range(len(parts))} return {sep.join(parts[: i + 1]) or sep for i in range(len(parts))}
@ -463,15 +465,15 @@ def symlink_or_skip(src, dst, **kwargs):
try: try:
os.symlink(str(src), str(dst), **kwargs) os.symlink(str(src), str(dst), **kwargs)
except OSError as e: except OSError as e:
skip(f"symlinks not supported: {e}") skip(f'symlinks not supported: {e}')
class ImportMode(Enum): class ImportMode(Enum):
"""Possible values for `mode` parameter of `import_path`.""" """Possible values for `mode` parameter of `import_path`."""
prepend = "prepend" prepend = 'prepend'
append = "append" append = 'append'
importlib = "importlib" importlib = 'importlib'
class ImportPathMismatchError(ImportError): class ImportPathMismatchError(ImportError):
@ -484,9 +486,9 @@ class ImportPathMismatchError(ImportError):
def import_path( def import_path(
path: Union[str, "os.PathLike[str]"], path: str | os.PathLike[str],
*, *,
mode: Union[str, ImportMode] = ImportMode.prepend, mode: str | ImportMode = ImportMode.prepend,
root: Path, root: Path,
consider_namespace_packages: bool, consider_namespace_packages: bool,
) -> ModuleType: ) -> ModuleType:
@ -534,7 +536,7 @@ def import_path(
# without touching sys.path. # without touching sys.path.
try: try:
pkg_root, module_name = resolve_pkg_root_and_module_name( pkg_root, module_name = resolve_pkg_root_and_module_name(
path, consider_namespace_packages=consider_namespace_packages path, consider_namespace_packages=consider_namespace_packages,
) )
except CouldNotResolvePathError: except CouldNotResolvePathError:
pass pass
@ -544,7 +546,7 @@ def import_path(
return sys.modules[module_name] return sys.modules[module_name]
mod = _import_module_using_spec( mod = _import_module_using_spec(
module_name, path, pkg_root, insert_modules=False module_name, path, pkg_root, insert_modules=False,
) )
if mod is not None: if mod is not None:
return mod return mod
@ -556,7 +558,7 @@ def import_path(
return sys.modules[module_name] return sys.modules[module_name]
mod = _import_module_using_spec( mod = _import_module_using_spec(
module_name, path, path.parent, insert_modules=True module_name, path, path.parent, insert_modules=True,
) )
if mod is None: if mod is None:
raise ImportError(f"Can't find module {module_name} at location {path}") raise ImportError(f"Can't find module {module_name} at location {path}")
@ -564,7 +566,7 @@ def import_path(
try: try:
pkg_root, module_name = resolve_pkg_root_and_module_name( pkg_root, module_name = resolve_pkg_root_and_module_name(
path, consider_namespace_packages=consider_namespace_packages path, consider_namespace_packages=consider_namespace_packages,
) )
except CouldNotResolvePathError: except CouldNotResolvePathError:
pkg_root, module_name = path.parent, path.stem pkg_root, module_name = path.parent, path.stem
@ -584,19 +586,19 @@ def import_path(
importlib.import_module(module_name) importlib.import_module(module_name)
mod = sys.modules[module_name] mod = sys.modules[module_name]
if path.name == "__init__.py": if path.name == '__init__.py':
return mod return mod
ignore = os.environ.get("PY_IGNORE_IMPORTMISMATCH", "") ignore = os.environ.get('PY_IGNORE_IMPORTMISMATCH', '')
if ignore != "1": if ignore != '1':
module_file = mod.__file__ module_file = mod.__file__
if module_file is None: if module_file is None:
raise ImportPathMismatchError(module_name, module_file, path) raise ImportPathMismatchError(module_name, module_file, path)
if module_file.endswith((".pyc", ".pyo")): if module_file.endswith(('.pyc', '.pyo')):
module_file = module_file[:-1] module_file = module_file[:-1]
if module_file.endswith(os.sep + "__init__.py"): if module_file.endswith(os.sep + '__init__.py'):
module_file = module_file[: -(len(os.sep + "__init__.py"))] module_file = module_file[: -(len(os.sep + '__init__.py'))]
try: try:
is_same = _is_same(str(path), module_file) is_same = _is_same(str(path), module_file)
@ -610,8 +612,8 @@ def import_path(
def _import_module_using_spec( def _import_module_using_spec(
module_name: str, module_path: Path, module_location: Path, *, insert_modules: bool module_name: str, module_path: Path, module_location: Path, *, insert_modules: bool,
) -> Optional[ModuleType]: ) -> ModuleType | None:
""" """
Tries to import a module by its canonical name, path to the .py file, and its Tries to import a module by its canonical name, path to the .py file, and its
parent location. parent location.
@ -645,7 +647,7 @@ def _import_module_using_spec(
# Implement a special _is_same function on Windows which returns True if the two filenames # Implement a special _is_same function on Windows which returns True if the two filenames
# compare equal, to circumvent os.path.samefile returning False for mounts in UNC (#7678). # compare equal, to circumvent os.path.samefile returning False for mounts in UNC (#7678).
if sys.platform.startswith("win"): if sys.platform.startswith('win'):
def _is_same(f1: str, f2: str) -> bool: def _is_same(f1: str, f2: str) -> bool:
return Path(f1) == Path(f2) or os.path.samefile(f1, f2) return Path(f1) == Path(f2) or os.path.samefile(f1, f2)
@ -663,7 +665,7 @@ def module_name_from_path(path: Path, root: Path) -> str:
For example: path="projects/src/tests/test_foo.py" and root="/projects", the For example: path="projects/src/tests/test_foo.py" and root="/projects", the
resulting module name will be "src.tests.test_foo". resulting module name will be "src.tests.test_foo".
""" """
path = path.with_suffix("") path = path.with_suffix('')
try: try:
relative_path = path.relative_to(root) relative_path = path.relative_to(root)
except ValueError: except ValueError:
@ -676,18 +678,18 @@ def module_name_from_path(path: Path, root: Path) -> str:
# Module name for packages do not contain the __init__ file, unless # Module name for packages do not contain the __init__ file, unless
# the `__init__.py` file is at the root. # the `__init__.py` file is at the root.
if len(path_parts) >= 2 and path_parts[-1] == "__init__": if len(path_parts) >= 2 and path_parts[-1] == '__init__':
path_parts = path_parts[:-1] path_parts = path_parts[:-1]
# Module names cannot contain ".", normalize them to "_". This prevents # Module names cannot contain ".", normalize them to "_". This prevents
# a directory having a "." in the name (".env.310" for example) causing extra intermediate modules. # a directory having a "." in the name (".env.310" for example) causing extra intermediate modules.
# Also, important to replace "." at the start of paths, as those are considered relative imports. # Also, important to replace "." at the start of paths, as those are considered relative imports.
path_parts = tuple(x.replace(".", "_") for x in path_parts) path_parts = tuple(x.replace('.', '_') for x in path_parts)
return ".".join(path_parts) return '.'.join(path_parts)
def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) -> None: def insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None:
""" """
Used by ``import_path`` to create intermediate modules when using mode=importlib. Used by ``import_path`` to create intermediate modules when using mode=importlib.
@ -695,10 +697,10 @@ def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) ->
to create empty modules "src" and "src.tests" after inserting "src.tests.test_foo", to create empty modules "src" and "src.tests" after inserting "src.tests.test_foo",
otherwise "src.tests.test_foo" is not importable by ``__import__``. otherwise "src.tests.test_foo" is not importable by ``__import__``.
""" """
module_parts = module_name.split(".") module_parts = module_name.split('.')
child_module: Union[ModuleType, None] = None child_module: ModuleType | None = None
module: Union[ModuleType, None] = None module: ModuleType | None = None
child_name: str = "" child_name: str = ''
while module_name: while module_name:
if module_name not in modules: if module_name not in modules:
try: try:
@ -723,12 +725,12 @@ def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) ->
setattr(module, child_name, child_module) setattr(module, child_name, child_module)
modules[module_name] = module modules[module_name] = module
# Keep track of the child module while moving up the tree. # Keep track of the child module while moving up the tree.
child_module, child_name = module, module_name.rpartition(".")[-1] child_module, child_name = module, module_name.rpartition('.')[-1]
module_parts.pop(-1) module_parts.pop(-1)
module_name = ".".join(module_parts) module_name = '.'.join(module_parts)
def resolve_package_path(path: Path) -> Optional[Path]: def resolve_package_path(path: Path) -> Path | None:
"""Return the Python package path by looking for the last """Return the Python package path by looking for the last
directory upwards which still contains an __init__.py. directory upwards which still contains an __init__.py.
@ -737,7 +739,7 @@ def resolve_package_path(path: Path) -> Optional[Path]:
result = None result = None
for parent in itertools.chain((path,), path.parents): for parent in itertools.chain((path,), path.parents):
if parent.is_dir(): if parent.is_dir():
if not (parent / "__init__.py").is_file(): if not (parent / '__init__.py').is_file():
break break
if not parent.name.isidentifier(): if not parent.name.isidentifier():
break break
@ -746,8 +748,8 @@ def resolve_package_path(path: Path) -> Optional[Path]:
def resolve_pkg_root_and_module_name( def resolve_pkg_root_and_module_name(
path: Path, *, consider_namespace_packages: bool = False path: Path, *, consider_namespace_packages: bool = False,
) -> Tuple[Path, str]: ) -> tuple[Path, str]:
""" """
Return the path to the directory of the root package that contains the Return the path to the directory of the root package that contains the
given Python file, and its module name: given Python file, and its module name:
@ -779,20 +781,20 @@ def resolve_pkg_root_and_module_name(
for parent in pkg_root.parents: for parent in pkg_root.parents:
# If any of the parent paths has a __init__.py, it means it is not # If any of the parent paths has a __init__.py, it means it is not
# a namespace package (see the docs linked above). # a namespace package (see the docs linked above).
if (parent / "__init__.py").is_file(): if (parent / '__init__.py').is_file():
break break
if str(parent) in sys.path: if str(parent) in sys.path:
# Point the pkg_root to the root of the namespace package. # Point the pkg_root to the root of the namespace package.
pkg_root = parent pkg_root = parent
break break
names = list(path.with_suffix("").relative_to(pkg_root).parts) names = list(path.with_suffix('').relative_to(pkg_root).parts)
if names[-1] == "__init__": if names[-1] == '__init__':
names.pop() names.pop()
module_name = ".".join(names) module_name = '.'.join(names)
return pkg_root, module_name return pkg_root, module_name
raise CouldNotResolvePathError(f"Could not resolve for {path}") raise CouldNotResolvePathError(f'Could not resolve for {path}')
class CouldNotResolvePathError(Exception): class CouldNotResolvePathError(Exception):
@ -800,9 +802,9 @@ class CouldNotResolvePathError(Exception):
def scandir( def scandir(
path: Union[str, "os.PathLike[str]"], path: str | os.PathLike[str],
sort_key: Callable[["os.DirEntry[str]"], object] = lambda entry: entry.name, sort_key: Callable[[os.DirEntry[str]], object] = lambda entry: entry.name,
) -> List["os.DirEntry[str]"]: ) -> list[os.DirEntry[str]]:
"""Scan a directory recursively, in breadth-first order. """Scan a directory recursively, in breadth-first order.
The returned entries are sorted according to the given key. The returned entries are sorted according to the given key.
@ -825,8 +827,8 @@ def scandir(
def visit( def visit(
path: Union[str, "os.PathLike[str]"], recurse: Callable[["os.DirEntry[str]"], bool] path: str | os.PathLike[str], recurse: Callable[[os.DirEntry[str]], bool],
) -> Iterator["os.DirEntry[str]"]: ) -> Iterator[os.DirEntry[str]]:
"""Walk a directory recursively, in breadth-first order. """Walk a directory recursively, in breadth-first order.
The `recurse` predicate determines whether a directory is recursed. The `recurse` predicate determines whether a directory is recursed.
@ -840,7 +842,7 @@ def visit(
yield from visit(entry.path, recurse) yield from visit(entry.path, recurse)
def absolutepath(path: Union[Path, str]) -> Path: def absolutepath(path: Path | str) -> Path:
"""Convert a path to an absolute path using os.path.abspath. """Convert a path to an absolute path using os.path.abspath.
Prefer this over Path.resolve() (see #6523). Prefer this over Path.resolve() (see #6523).
@ -849,7 +851,7 @@ def absolutepath(path: Union[Path, str]) -> Path:
return Path(os.path.abspath(str(path))) return Path(os.path.abspath(str(path)))
def commonpath(path1: Path, path2: Path) -> Optional[Path]: def commonpath(path1: Path, path2: Path) -> Path | None:
"""Return the common part shared with the other path, or None if there is """Return the common part shared with the other path, or None if there is
no common part. no common part.

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,10 @@
"""Helper plugin for pytester; should not be loaded on its own.""" """Helper plugin for pytester; should not be loaded on its own."""
# This plugin contains assertions used by pytester. pytester cannot # This plugin contains assertions used by pytester. pytester cannot
# contain them itself, since it is imported by the `pytest` module, # contain them itself, since it is imported by the `pytest` module,
# hence cannot be subject to assertion rewriting, which requires a # hence cannot be subject to assertion rewriting, which requires a
# module to not be already imported. # module to not be already imported.
from __future__ import annotations
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
@ -15,10 +16,10 @@ from _pytest.reports import TestReport
def assertoutcome( def assertoutcome(
outcomes: Tuple[ outcomes: tuple[
Sequence[TestReport], Sequence[TestReport],
Sequence[Union[CollectReport, TestReport]], Sequence[CollectReport | TestReport],
Sequence[Union[CollectReport, TestReport]], Sequence[CollectReport | TestReport],
], ],
passed: int = 0, passed: int = 0,
skipped: int = 0, skipped: int = 0,
@ -28,49 +29,49 @@ def assertoutcome(
realpassed, realskipped, realfailed = outcomes realpassed, realskipped, realfailed = outcomes
obtained = { obtained = {
"passed": len(realpassed), 'passed': len(realpassed),
"skipped": len(realskipped), 'skipped': len(realskipped),
"failed": len(realfailed), 'failed': len(realfailed),
} }
expected = {"passed": passed, "skipped": skipped, "failed": failed} expected = {'passed': passed, 'skipped': skipped, 'failed': failed}
assert obtained == expected, outcomes assert obtained == expected, outcomes
def assert_outcomes( def assert_outcomes(
outcomes: Dict[str, int], outcomes: dict[str, int],
passed: int = 0, passed: int = 0,
skipped: int = 0, skipped: int = 0,
failed: int = 0, failed: int = 0,
errors: int = 0, errors: int = 0,
xpassed: int = 0, xpassed: int = 0,
xfailed: int = 0, xfailed: int = 0,
warnings: Optional[int] = None, warnings: int | None = None,
deselected: Optional[int] = None, deselected: int | None = None,
) -> None: ) -> None:
"""Assert that the specified outcomes appear with the respective """Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run.""" numbers (0 means it didn't occur) in the text output from a test run."""
__tracebackhide__ = True __tracebackhide__ = True
obtained = { obtained = {
"passed": outcomes.get("passed", 0), 'passed': outcomes.get('passed', 0),
"skipped": outcomes.get("skipped", 0), 'skipped': outcomes.get('skipped', 0),
"failed": outcomes.get("failed", 0), 'failed': outcomes.get('failed', 0),
"errors": outcomes.get("errors", 0), 'errors': outcomes.get('errors', 0),
"xpassed": outcomes.get("xpassed", 0), 'xpassed': outcomes.get('xpassed', 0),
"xfailed": outcomes.get("xfailed", 0), 'xfailed': outcomes.get('xfailed', 0),
} }
expected = { expected = {
"passed": passed, 'passed': passed,
"skipped": skipped, 'skipped': skipped,
"failed": failed, 'failed': failed,
"errors": errors, 'errors': errors,
"xpassed": xpassed, 'xpassed': xpassed,
"xfailed": xfailed, 'xfailed': xfailed,
} }
if warnings is not None: if warnings is not None:
obtained["warnings"] = outcomes.get("warnings", 0) obtained['warnings'] = outcomes.get('warnings', 0)
expected["warnings"] = warnings expected['warnings'] = warnings
if deselected is not None: if deselected is not None:
obtained["deselected"] = outcomes.get("deselected", 0) obtained['deselected'] = outcomes.get('deselected', 0)
expected["deselected"] = deselected expected['deselected'] = deselected
assert obtained == expected assert obtained == expected

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,12 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import math
import pprint
from collections.abc import Collection from collections.abc import Collection
from collections.abc import Sized from collections.abc import Sized
from decimal import Decimal from decimal import Decimal
import math
from numbers import Complex from numbers import Complex
import pprint
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
@ -33,25 +35,25 @@ if TYPE_CHECKING:
def _compare_approx( def _compare_approx(
full_object: object, full_object: object,
message_data: Sequence[Tuple[str, str, str]], message_data: Sequence[tuple[str, str, str]],
number_of_elements: int, number_of_elements: int,
different_ids: Sequence[object], different_ids: Sequence[object],
max_abs_diff: float, max_abs_diff: float,
max_rel_diff: float, max_rel_diff: float,
) -> List[str]: ) -> list[str]:
message_list = list(message_data) message_list = list(message_data)
message_list.insert(0, ("Index", "Obtained", "Expected")) message_list.insert(0, ('Index', 'Obtained', 'Expected'))
max_sizes = [0, 0, 0] max_sizes = [0, 0, 0]
for index, obtained, expected in message_list: for index, obtained, expected in message_list:
max_sizes[0] = max(max_sizes[0], len(index)) max_sizes[0] = max(max_sizes[0], len(index))
max_sizes[1] = max(max_sizes[1], len(obtained)) max_sizes[1] = max(max_sizes[1], len(obtained))
max_sizes[2] = max(max_sizes[2], len(expected)) max_sizes[2] = max(max_sizes[2], len(expected))
explanation = [ explanation = [
f"comparison failed. Mismatched elements: {len(different_ids)} / {number_of_elements}:", f'comparison failed. Mismatched elements: {len(different_ids)} / {number_of_elements}:',
f"Max absolute difference: {max_abs_diff}", f'Max absolute difference: {max_abs_diff}',
f"Max relative difference: {max_rel_diff}", f'Max relative difference: {max_rel_diff}',
] + [ ] + [
f"{indexes:<{max_sizes[0]}} | {obtained:<{max_sizes[1]}} | {expected:<{max_sizes[2]}}" f'{indexes:<{max_sizes[0]}} | {obtained:<{max_sizes[1]}} | {expected:<{max_sizes[2]}}'
for indexes, obtained, expected in message_list for indexes, obtained, expected in message_list
] ]
return explanation return explanation
@ -79,11 +81,11 @@ class ApproxBase:
def __repr__(self) -> str: def __repr__(self) -> str:
raise NotImplementedError raise NotImplementedError
def _repr_compare(self, other_side: Any) -> List[str]: def _repr_compare(self, other_side: Any) -> list[str]:
return [ return [
"comparison failed", 'comparison failed',
f"Obtained: {other_side}", f'Obtained: {other_side}',
f"Expected: {self}", f'Expected: {self}',
] ]
def __eq__(self, actual) -> bool: def __eq__(self, actual) -> bool:
@ -94,7 +96,7 @@ class ApproxBase:
def __bool__(self): def __bool__(self):
__tracebackhide__ = True __tracebackhide__ = True
raise AssertionError( raise AssertionError(
"approx() is not supported in a boolean context.\nDid you mean: `assert a == approx(b)`?" 'approx() is not supported in a boolean context.\nDid you mean: `assert a == approx(b)`?',
) )
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
@ -103,7 +105,7 @@ class ApproxBase:
def __ne__(self, actual) -> bool: def __ne__(self, actual) -> bool:
return not (actual == self) return not (actual == self)
def _approx_scalar(self, x) -> "ApproxScalar": def _approx_scalar(self, x) -> ApproxScalar:
if isinstance(x, Decimal): if isinstance(x, Decimal):
return ApproxDecimal(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) return ApproxDecimal(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
@ -138,16 +140,16 @@ class ApproxNumpy(ApproxBase):
def __repr__(self) -> str: def __repr__(self) -> str:
list_scalars = _recursive_sequence_map( list_scalars = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist() self._approx_scalar, self.expected.tolist(),
) )
return f"approx({list_scalars!r})" return f'approx({list_scalars!r})'
def _repr_compare(self, other_side: "ndarray") -> List[str]: def _repr_compare(self, other_side: ndarray) -> list[str]:
import itertools import itertools
import math import math
def get_value_from_nested_list( def get_value_from_nested_list(
nested_list: List[Any], nd_index: Tuple[Any, ...] nested_list: list[Any], nd_index: tuple[Any, ...],
) -> Any: ) -> Any:
""" """
Helper function to get the value out of a nested list, given an n-dimensional index. Helper function to get the value out of a nested list, given an n-dimensional index.
@ -160,13 +162,13 @@ class ApproxNumpy(ApproxBase):
np_array_shape = self.expected.shape np_array_shape = self.expected.shape
approx_side_as_seq = _recursive_sequence_map( approx_side_as_seq = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist() self._approx_scalar, self.expected.tolist(),
) )
if np_array_shape != other_side.shape: if np_array_shape != other_side.shape:
return [ return [
"Impossible to compare arrays with different shapes.", 'Impossible to compare arrays with different shapes.',
f"Shapes: {np_array_shape} and {other_side.shape}", f'Shapes: {np_array_shape} and {other_side.shape}',
] ]
number_of_elements = self.expected.size number_of_elements = self.expected.size
@ -238,9 +240,9 @@ class ApproxMapping(ApproxBase):
with numeric values (the keys can be anything).""" with numeric values (the keys can be anything)."""
def __repr__(self) -> str: def __repr__(self) -> str:
return f"approx({({k: self._approx_scalar(v) for k, v in self.expected.items()})!r})" return f'approx({({k: self._approx_scalar(v) for k, v in self.expected.items()})!r})'
def _repr_compare(self, other_side: Mapping[object, float]) -> List[str]: def _repr_compare(self, other_side: Mapping[object, float]) -> list[str]:
import math import math
approx_side_as_map = { approx_side_as_map = {
@ -252,12 +254,12 @@ class ApproxMapping(ApproxBase):
max_rel_diff = -math.inf max_rel_diff = -math.inf
different_ids = [] different_ids = []
for (approx_key, approx_value), other_value in zip( for (approx_key, approx_value), other_value in zip(
approx_side_as_map.items(), other_side.values() approx_side_as_map.items(), other_side.values(),
): ):
if approx_value != other_value: if approx_value != other_value:
if approx_value.expected is not None and other_value is not None: if approx_value.expected is not None and other_value is not None:
max_abs_diff = max( max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value) max_abs_diff, abs(approx_value.expected - other_value),
) )
if approx_value.expected == 0.0: if approx_value.expected == 0.0:
max_rel_diff = math.inf max_rel_diff = math.inf
@ -265,8 +267,8 @@ class ApproxMapping(ApproxBase):
max_rel_diff = max( max_rel_diff = max(
max_rel_diff, max_rel_diff,
abs( abs(
(approx_value.expected - other_value) (approx_value.expected - other_value) /
/ approx_value.expected approx_value.expected,
), ),
) )
different_ids.append(approx_key) different_ids.append(approx_key)
@ -302,7 +304,7 @@ class ApproxMapping(ApproxBase):
__tracebackhide__ = True __tracebackhide__ = True
for key, value in self.expected.items(): for key, value in self.expected.items():
if isinstance(value, type(self.expected)): if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}" msg = 'pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}'
raise TypeError(msg.format(key, value, pprint.pformat(self.expected))) raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
@ -313,15 +315,15 @@ class ApproxSequenceLike(ApproxBase):
seq_type = type(self.expected) seq_type = type(self.expected)
if seq_type not in (tuple, list): if seq_type not in (tuple, list):
seq_type = list seq_type = list
return f"approx({seq_type(self._approx_scalar(x) for x in self.expected)!r})" return f'approx({seq_type(self._approx_scalar(x) for x in self.expected)!r})'
def _repr_compare(self, other_side: Sequence[float]) -> List[str]: def _repr_compare(self, other_side: Sequence[float]) -> list[str]:
import math import math
if len(self.expected) != len(other_side): if len(self.expected) != len(other_side):
return [ return [
"Impossible to compare lists with different sizes.", 'Impossible to compare lists with different sizes.',
f"Lengths: {len(self.expected)} and {len(other_side)}", f'Lengths: {len(self.expected)} and {len(other_side)}',
] ]
approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected) approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected)
@ -331,7 +333,7 @@ class ApproxSequenceLike(ApproxBase):
max_rel_diff = -math.inf max_rel_diff = -math.inf
different_ids = [] different_ids = []
for i, (approx_value, other_value) in enumerate( for i, (approx_value, other_value) in enumerate(
zip(approx_side_as_map, other_side) zip(approx_side_as_map, other_side),
): ):
if approx_value != other_value: if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value) abs_diff = abs(approx_value.expected - other_value)
@ -371,7 +373,7 @@ class ApproxSequenceLike(ApproxBase):
__tracebackhide__ = True __tracebackhide__ = True
for index, x in enumerate(self.expected): for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)): if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}" msg = 'pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}'
raise TypeError(msg.format(x, index, pprint.pformat(self.expected))) raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
@ -380,8 +382,8 @@ class ApproxScalar(ApproxBase):
# Using Real should be better than this Union, but not possible yet: # Using Real should be better than this Union, but not possible yet:
# https://github.com/python/typeshed/pull/3108 # https://github.com/python/typeshed/pull/3108
DEFAULT_ABSOLUTE_TOLERANCE: Union[float, Decimal] = 1e-12 DEFAULT_ABSOLUTE_TOLERANCE: float | Decimal = 1e-12
DEFAULT_RELATIVE_TOLERANCE: Union[float, Decimal] = 1e-6 DEFAULT_RELATIVE_TOLERANCE: float | Decimal = 1e-6
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return a string communicating both the expected value and the """Return a string communicating both the expected value and the
@ -393,24 +395,24 @@ class ApproxScalar(ApproxBase):
# tolerances, i.e. non-numerics and infinities. Need to call abs to # tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j). # handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf( if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected) # type: ignore[arg-type] abs(self.expected), # type: ignore[arg-type]
): ):
return str(self.expected) return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will # If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'. # raise a ValueError. In this case, display '???'.
try: try:
vetted_tolerance = f"{self.tolerance:.1e}" vetted_tolerance = f'{self.tolerance:.1e}'
if ( if (
isinstance(self.expected, Complex) isinstance(self.expected, Complex) and
and self.expected.imag self.expected.imag and
and not math.isinf(self.tolerance) not math.isinf(self.tolerance)
): ):
vetted_tolerance += " ∠ ±180°" vetted_tolerance += ' ∠ ±180°'
except ValueError: except ValueError:
vetted_tolerance = "???" vetted_tolerance = '???'
return f"{self.expected} ± {vetted_tolerance}" return f'{self.expected} ± {vetted_tolerance}'
def __eq__(self, actual) -> bool: def __eq__(self, actual) -> bool:
"""Return whether the given value is equal to the expected value """Return whether the given value is equal to the expected value
@ -429,8 +431,8 @@ class ApproxScalar(ApproxBase):
# NB: we need Complex, rather than just Number, to ensure that __abs__, # NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined. # __sub__, and __float__ are defined.
if not ( if not (
isinstance(self.expected, (Complex, Decimal)) isinstance(self.expected, (Complex, Decimal)) and
and isinstance(actual, (Complex, Decimal)) isinstance(actual, (Complex, Decimal))
): ):
return False return False
@ -473,7 +475,7 @@ class ApproxScalar(ApproxBase):
if absolute_tolerance < 0: if absolute_tolerance < 0:
raise ValueError( raise ValueError(
f"absolute tolerance can't be negative: {absolute_tolerance}" f"absolute tolerance can't be negative: {absolute_tolerance}",
) )
if math.isnan(absolute_tolerance): if math.isnan(absolute_tolerance):
raise ValueError("absolute tolerance can't be NaN.") raise ValueError("absolute tolerance can't be NaN.")
@ -490,12 +492,12 @@ class ApproxScalar(ApproxBase):
# because we don't want to raise errors about the relative tolerance if # because we don't want to raise errors about the relative tolerance if
# we aren't even going to use it. # we aren't even going to use it.
relative_tolerance = set_default( relative_tolerance = set_default(
self.rel, self.DEFAULT_RELATIVE_TOLERANCE self.rel, self.DEFAULT_RELATIVE_TOLERANCE,
) * abs(self.expected) ) * abs(self.expected)
if relative_tolerance < 0: if relative_tolerance < 0:
raise ValueError( raise ValueError(
f"relative tolerance can't be negative: {relative_tolerance}" f"relative tolerance can't be negative: {relative_tolerance}",
) )
if math.isnan(relative_tolerance): if math.isnan(relative_tolerance):
raise ValueError("relative tolerance can't be NaN.") raise ValueError("relative tolerance can't be NaN.")
@ -507,8 +509,8 @@ class ApproxScalar(ApproxBase):
class ApproxDecimal(ApproxScalar): class ApproxDecimal(ApproxScalar):
"""Perform approximate comparisons where the expected value is a Decimal.""" """Perform approximate comparisons where the expected value is a Decimal."""
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12") DEFAULT_ABSOLUTE_TOLERANCE = Decimal('1e-12')
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6") DEFAULT_RELATIVE_TOLERANCE = Decimal('1e-6')
def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase: def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
@ -711,20 +713,20 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
__tracebackhide__ = True __tracebackhide__ = True
if isinstance(expected, Decimal): if isinstance(expected, Decimal):
cls: Type[ApproxBase] = ApproxDecimal cls: type[ApproxBase] = ApproxDecimal
elif isinstance(expected, Mapping): elif isinstance(expected, Mapping):
cls = ApproxMapping cls = ApproxMapping
elif _is_numpy_array(expected): elif _is_numpy_array(expected):
expected = _as_numpy_array(expected) expected = _as_numpy_array(expected)
cls = ApproxNumpy cls = ApproxNumpy
elif ( elif (
hasattr(expected, "__getitem__") hasattr(expected, '__getitem__') and
and isinstance(expected, Sized) isinstance(expected, Sized) and
and not isinstance(expected, (str, bytes)) not isinstance(expected, (str, bytes))
): ):
cls = ApproxSequenceLike cls = ApproxSequenceLike
elif isinstance(expected, Collection) and not isinstance(expected, (str, bytes)): elif isinstance(expected, Collection) and not isinstance(expected, (str, bytes)):
msg = f"pytest.approx() only supports ordered sequences, but got: {expected!r}" msg = f'pytest.approx() only supports ordered sequences, but got: {expected!r}'
raise TypeError(msg) raise TypeError(msg)
else: else:
cls = ApproxScalar cls = ApproxScalar
@ -740,42 +742,42 @@ def _is_numpy_array(obj: object) -> bool:
return _as_numpy_array(obj) is not None return _as_numpy_array(obj) is not None
def _as_numpy_array(obj: object) -> Optional["ndarray"]: def _as_numpy_array(obj: object) -> ndarray | None:
""" """
Return an ndarray if the given object is implicitly convertible to ndarray, Return an ndarray if the given object is implicitly convertible to ndarray,
and numpy is already imported, otherwise None. and numpy is already imported, otherwise None.
""" """
import sys import sys
np: Any = sys.modules.get("numpy") np: Any = sys.modules.get('numpy')
if np is not None: if np is not None:
# avoid infinite recursion on numpy scalars, which have __array__ # avoid infinite recursion on numpy scalars, which have __array__
if np.isscalar(obj): if np.isscalar(obj):
return None return None
elif isinstance(obj, np.ndarray): elif isinstance(obj, np.ndarray):
return obj return obj
elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"): elif hasattr(obj, '__array__') or hasattr('obj', '__array_interface__'):
return np.asarray(obj) return np.asarray(obj)
return None return None
# builtin pytest.raises helper # builtin pytest.raises helper
E = TypeVar("E", bound=BaseException) E = TypeVar('E', bound=BaseException)
@overload @overload
def raises( def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]], expected_exception: type[E] | tuple[type[E], ...],
*, *,
match: Optional[Union[str, Pattern[str]]] = ..., match: str | Pattern[str] | None = ...,
) -> "RaisesContext[E]": ) -> RaisesContext[E]:
... ...
@overload @overload
def raises( def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]], expected_exception: type[E] | tuple[type[E], ...],
func: Callable[..., Any], func: Callable[..., Any],
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
@ -784,8 +786,8 @@ def raises(
def raises( def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]], *args: Any, **kwargs: Any expected_exception: type[E] | tuple[type[E], ...], *args: Any, **kwargs: Any,
) -> Union["RaisesContext[E]", _pytest._code.ExceptionInfo[E]]: ) -> RaisesContext[E] | _pytest._code.ExceptionInfo[E]:
r"""Assert that a code block/function call raises an exception type, or one of its subclasses. r"""Assert that a code block/function call raises an exception type, or one of its subclasses.
:param expected_exception: :param expected_exception:
@ -928,34 +930,34 @@ def raises(
if not expected_exception: if not expected_exception:
raise ValueError( raise ValueError(
f"Expected an exception type or a tuple of exception types, but got `{expected_exception!r}`. " f'Expected an exception type or a tuple of exception types, but got `{expected_exception!r}`. '
f"Raising exceptions is already understood as failing the test, so you don't need " f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'." f"any special code to say 'this should never raise an exception'.",
) )
if isinstance(expected_exception, type): if isinstance(expected_exception, type):
expected_exceptions: Tuple[Type[E], ...] = (expected_exception,) expected_exceptions: tuple[type[E], ...] = (expected_exception,)
else: else:
expected_exceptions = expected_exception expected_exceptions = expected_exception
for exc in expected_exceptions: for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException): if not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable] msg = 'expected exception must be a BaseException type, not {}' # type: ignore[unreachable]
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__ not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a)) raise TypeError(msg.format(not_a))
message = f"DID NOT RAISE {expected_exception}" message = f'DID NOT RAISE {expected_exception}'
if not args: if not args:
match: Optional[Union[str, Pattern[str]]] = kwargs.pop("match", None) match: str | Pattern[str] | None = kwargs.pop('match', None)
if kwargs: if kwargs:
msg = "Unexpected keyword arguments passed to pytest.raises: " msg = 'Unexpected keyword arguments passed to pytest.raises: '
msg += ", ".join(sorted(kwargs)) msg += ', '.join(sorted(kwargs))
msg += "\nUse context-manager form instead?" msg += '\nUse context-manager form instead?'
raise TypeError(msg) raise TypeError(msg)
return RaisesContext(expected_exception, message, match) return RaisesContext(expected_exception, message, match)
else: else:
func = args[0] func = args[0]
if not callable(func): if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable") raise TypeError(f'{func!r} object (type: {type(func)}) must be callable')
try: try:
func(*args[1:], **kwargs) func(*args[1:], **kwargs)
except expected_exception as e: except expected_exception as e:
@ -971,14 +973,14 @@ raises.Exception = fail.Exception # type: ignore
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]): class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__( def __init__(
self, self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]], expected_exception: type[E] | tuple[type[E], ...],
message: str, message: str,
match_expr: Optional[Union[str, Pattern[str]]] = None, match_expr: str | Pattern[str] | None = None,
) -> None: ) -> None:
self.expected_exception = expected_exception self.expected_exception = expected_exception
self.message = message self.message = message
self.match_expr = match_expr self.match_expr = match_expr
self.excinfo: Optional[_pytest._code.ExceptionInfo[E]] = None self.excinfo: _pytest._code.ExceptionInfo[E] | None = None
def __enter__(self) -> _pytest._code.ExceptionInfo[E]: def __enter__(self) -> _pytest._code.ExceptionInfo[E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later() self.excinfo = _pytest._code.ExceptionInfo.for_later()
@ -986,9 +988,9 @@ class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> bool: ) -> bool:
__tracebackhide__ = True __tracebackhide__ = True
if exc_type is None: if exc_type is None:

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import sys import sys
import pytest import pytest
@ -6,19 +8,19 @@ from pytest import Parser
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
parser.addini("pythonpath", type="paths", help="Add paths to sys.path", default=[]) parser.addini('pythonpath', type='paths', help='Add paths to sys.path', default=[])
@pytest.hookimpl(tryfirst=True) @pytest.hookimpl(tryfirst=True)
def pytest_load_initial_conftests(early_config: Config) -> None: def pytest_load_initial_conftests(early_config: Config) -> None:
# `pythonpath = a b` will set `sys.path` to `[a, b, x, y, z, ...]` # `pythonpath = a b` will set `sys.path` to `[a, b, x, y, z, ...]`
for path in reversed(early_config.getini("pythonpath")): for path in reversed(early_config.getini('pythonpath')):
sys.path.insert(0, str(path)) sys.path.insert(0, str(path))
@pytest.hookimpl(trylast=True) @pytest.hookimpl(trylast=True)
def pytest_unconfigure(config: Config) -> None: def pytest_unconfigure(config: Config) -> None:
for path in config.getini("pythonpath"): for path in config.getini('pythonpath'):
path_str = str(path) path_str = str(path)
if path_str in sys.path: if path_str in sys.path:
sys.path.remove(path_str) sys.path.remove(path_str)

View file

@ -1,7 +1,10 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Record warnings during test function execution.""" """Record warnings during test function execution."""
from pprint import pformat from __future__ import annotations
import re import re
import warnings
from pprint import pformat
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
@ -16,7 +19,6 @@ from typing import Tuple
from typing import Type from typing import Type
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
import warnings
from _pytest.deprecated import check_ispytest from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture from _pytest.fixtures import fixture
@ -24,11 +26,11 @@ from _pytest.outcomes import Exit
from _pytest.outcomes import fail from _pytest.outcomes import fail
T = TypeVar("T") T = TypeVar('T')
@fixture @fixture
def recwarn() -> Generator["WarningsRecorder", None, None]: def recwarn() -> Generator[WarningsRecorder, None, None]:
"""Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions. """Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
See https://docs.pytest.org/en/latest/how-to/capture-warnings.html for information See https://docs.pytest.org/en/latest/how-to/capture-warnings.html for information
@ -36,14 +38,14 @@ def recwarn() -> Generator["WarningsRecorder", None, None]:
""" """
wrec = WarningsRecorder(_ispytest=True) wrec = WarningsRecorder(_ispytest=True)
with wrec: with wrec:
warnings.simplefilter("default") warnings.simplefilter('default')
yield wrec yield wrec
@overload @overload
def deprecated_call( def deprecated_call(
*, match: Optional[Union[str, Pattern[str]]] = ... *, match: str | Pattern[str] | None = ...,
) -> "WarningsRecorder": ) -> WarningsRecorder:
... ...
@ -53,8 +55,8 @@ def deprecated_call(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
def deprecated_call( def deprecated_call(
func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any func: Callable[..., Any] | None = None, *args: Any, **kwargs: Any,
) -> Union["WarningsRecorder", Any]: ) -> WarningsRecorder | Any:
"""Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning`` or ``FutureWarning``. """Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning`` or ``FutureWarning``.
This function can be used as a context manager:: This function can be used as a context manager::
@ -82,22 +84,22 @@ def deprecated_call(
if func is not None: if func is not None:
args = (func, *args) args = (func, *args)
return warns( return warns(
(DeprecationWarning, PendingDeprecationWarning, FutureWarning), *args, **kwargs (DeprecationWarning, PendingDeprecationWarning, FutureWarning), *args, **kwargs,
) )
@overload @overload
def warns( def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = ..., expected_warning: type[Warning] | tuple[type[Warning], ...] = ...,
*, *,
match: Optional[Union[str, Pattern[str]]] = ..., match: str | Pattern[str] | None = ...,
) -> "WarningsChecker": ) -> WarningsChecker:
... ...
@overload @overload
def warns( def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]], expected_warning: type[Warning] | tuple[type[Warning], ...],
func: Callable[..., T], func: Callable[..., T],
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
@ -106,11 +108,11 @@ def warns(
def warns( def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = Warning, expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
*args: Any, *args: Any,
match: Optional[Union[str, Pattern[str]]] = None, match: str | Pattern[str] | None = None,
**kwargs: Any, **kwargs: Any,
) -> Union["WarningsChecker", Any]: ) -> WarningsChecker | Any:
r"""Assert that code raises a particular class of warning. r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or tuple Specifically, the parameter ``expected_warning`` can be a warning class or tuple
@ -155,16 +157,16 @@ def warns(
__tracebackhide__ = True __tracebackhide__ = True
if not args: if not args:
if kwargs: if kwargs:
argnames = ", ".join(sorted(kwargs)) argnames = ', '.join(sorted(kwargs))
raise TypeError( raise TypeError(
f"Unexpected keyword arguments passed to pytest.warns: {argnames}" f'Unexpected keyword arguments passed to pytest.warns: {argnames}'
"\nUse context-manager form instead?" '\nUse context-manager form instead?',
) )
return WarningsChecker(expected_warning, match_expr=match, _ispytest=True) return WarningsChecker(expected_warning, match_expr=match, _ispytest=True)
else: else:
func = args[0] func = args[0]
if not callable(func): if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable") raise TypeError(f'{func!r} object (type: {type(func)}) must be callable')
with WarningsChecker(expected_warning, _ispytest=True): with WarningsChecker(expected_warning, _ispytest=True):
return func(*args[1:], **kwargs) return func(*args[1:], **kwargs)
@ -187,18 +189,18 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
# Type ignored due to the way typeshed handles warnings.catch_warnings. # Type ignored due to the way typeshed handles warnings.catch_warnings.
super().__init__(record=True) # type: ignore[call-arg] super().__init__(record=True) # type: ignore[call-arg]
self._entered = False self._entered = False
self._list: List[warnings.WarningMessage] = [] self._list: list[warnings.WarningMessage] = []
@property @property
def list(self) -> List["warnings.WarningMessage"]: def list(self) -> list[warnings.WarningMessage]:
"""The list of recorded warnings.""" """The list of recorded warnings."""
return self._list return self._list
def __getitem__(self, i: int) -> "warnings.WarningMessage": def __getitem__(self, i: int) -> warnings.WarningMessage:
"""Get a recorded warning by index.""" """Get a recorded warning by index."""
return self._list[i] return self._list[i]
def __iter__(self) -> Iterator["warnings.WarningMessage"]: def __iter__(self) -> Iterator[warnings.WarningMessage]:
"""Iterate through the recorded warnings.""" """Iterate through the recorded warnings."""
return iter(self._list) return iter(self._list)
@ -206,24 +208,24 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
"""The number of recorded warnings.""" """The number of recorded warnings."""
return len(self._list) return len(self._list)
def pop(self, cls: Type[Warning] = Warning) -> "warnings.WarningMessage": def pop(self, cls: type[Warning] = Warning) -> warnings.WarningMessage:
"""Pop the first recorded warning which is an instance of ``cls``, """Pop the first recorded warning which is an instance of ``cls``,
but not an instance of a child class of any other match. but not an instance of a child class of any other match.
Raises ``AssertionError`` if there is no match. Raises ``AssertionError`` if there is no match.
""" """
best_idx: Optional[int] = None best_idx: int | None = None
for i, w in enumerate(self._list): for i, w in enumerate(self._list):
if w.category == cls: if w.category == cls:
return self._list.pop(i) # exact match, stop looking return self._list.pop(i) # exact match, stop looking
if issubclass(w.category, cls) and ( if issubclass(w.category, cls) and (
best_idx is None best_idx is None or
or not issubclass(w.category, self._list[best_idx].category) not issubclass(w.category, self._list[best_idx].category)
): ):
best_idx = i best_idx = i
if best_idx is not None: if best_idx is not None:
return self._list.pop(best_idx) return self._list.pop(best_idx)
__tracebackhide__ = True __tracebackhide__ = True
raise AssertionError(f"{cls!r} not found in warning list") raise AssertionError(f'{cls!r} not found in warning list')
def clear(self) -> None: def clear(self) -> None:
"""Clear the list of recorded warnings.""" """Clear the list of recorded warnings."""
@ -231,26 +233,26 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__ # Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
# -- it returns a List but we only emulate one. # -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore def __enter__(self) -> WarningsRecorder: # type: ignore
if self._entered: if self._entered:
__tracebackhide__ = True __tracebackhide__ = True
raise RuntimeError(f"Cannot enter {self!r} twice") raise RuntimeError(f'Cannot enter {self!r} twice')
_list = super().__enter__() _list = super().__enter__()
# record=True means it's None. # record=True means it's None.
assert _list is not None assert _list is not None
self._list = _list self._list = _list
warnings.simplefilter("always") warnings.simplefilter('always')
return self return self
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
if not self._entered: if not self._entered:
__tracebackhide__ = True __tracebackhide__ = True
raise RuntimeError(f"Cannot exit {self!r} without entering first") raise RuntimeError(f'Cannot exit {self!r} without entering first')
super().__exit__(exc_type, exc_val, exc_tb) super().__exit__(exc_type, exc_val, exc_tb)
@ -263,22 +265,22 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
class WarningsChecker(WarningsRecorder): class WarningsChecker(WarningsRecorder):
def __init__( def __init__(
self, self,
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = Warning, expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
match_expr: Optional[Union[str, Pattern[str]]] = None, match_expr: str | Pattern[str] | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
super().__init__(_ispytest=True) super().__init__(_ispytest=True)
msg = "exceptions must be derived from Warning, not %s" msg = 'exceptions must be derived from Warning, not %s'
if isinstance(expected_warning, tuple): if isinstance(expected_warning, tuple):
for exc in expected_warning: for exc in expected_warning:
if not issubclass(exc, Warning): if not issubclass(exc, Warning):
raise TypeError(msg % type(exc)) raise TypeError(msg % type(exc))
expected_warning_tup = expected_warning expected_warning_tup = expected_warning
elif isinstance(expected_warning, type) and issubclass( elif isinstance(expected_warning, type) and issubclass(
expected_warning, Warning expected_warning, Warning,
): ):
expected_warning_tup = (expected_warning,) expected_warning_tup = (expected_warning,)
else: else:
@ -290,14 +292,14 @@ class WarningsChecker(WarningsRecorder):
def matches(self, warning: warnings.WarningMessage) -> bool: def matches(self, warning: warnings.WarningMessage) -> bool:
assert self.expected_warning is not None assert self.expected_warning is not None
return issubclass(warning.category, self.expected_warning) and bool( return issubclass(warning.category, self.expected_warning) and bool(
self.match_expr is None or re.search(self.match_expr, str(warning.message)) self.match_expr is None or re.search(self.match_expr, str(warning.message)),
) )
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
super().__exit__(exc_type, exc_val, exc_tb) super().__exit__(exc_type, exc_val, exc_tb)
@ -308,9 +310,9 @@ class WarningsChecker(WarningsRecorder):
# when the warning doesn't happen. Control-flow exceptions should always # when the warning doesn't happen. Control-flow exceptions should always
# propagate. # propagate.
if exc_val is not None and ( if exc_val is not None and (
not isinstance(exc_val, Exception) not isinstance(exc_val, Exception) or
# Exit is an Exception, not a BaseException, for some reason. # Exit is an Exception, not a BaseException, for some reason.
or isinstance(exc_val, Exit) isinstance(exc_val, Exit)
): ):
return return
@ -320,14 +322,14 @@ class WarningsChecker(WarningsRecorder):
try: try:
if not any(issubclass(w.category, self.expected_warning) for w in self): if not any(issubclass(w.category, self.expected_warning) for w in self):
fail( fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n" f'DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n'
f" Emitted warnings: {found_str()}." f' Emitted warnings: {found_str()}.',
) )
elif not any(self.matches(w) for w in self): elif not any(self.matches(w) for w in self):
fail( fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n" f'DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n'
f" Regex: {self.match_expr}\n" f' Regex: {self.match_expr}\n'
f" Emitted warnings: {found_str()}." f' Emitted warnings: {found_str()}.',
) )
finally: finally:
# Whether or not any warnings matched, we want to re-emit all unmatched warnings. # Whether or not any warnings matched, we want to re-emit all unmatched warnings.
@ -366,5 +368,5 @@ class WarningsChecker(WarningsRecorder):
# its first argument was not a string. But that case can't be # its first argument was not a string. But that case can't be
# distinguished from an invalid type. # distinguished from an invalid type.
raise TypeError( raise TypeError(
f"Warning must be str or Warning, got {msg!r} (type {type(msg).__name__})" f'Warning must be str or Warning, got {msg!r} (type {type(msg).__name__})',
) )

View file

@ -1,7 +1,9 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses import dataclasses
from io import StringIO
import os import os
from io import StringIO
from pprint import pprint from pprint import pprint
from typing import Any from typing import Any
from typing import cast from typing import cast
@ -47,25 +49,25 @@ def getworkerinfoline(node):
return node._workerinfocache return node._workerinfocache
except AttributeError: except AttributeError:
d = node.workerinfo d = node.workerinfo
ver = "{}.{}.{}".format(*d["version_info"][:3]) ver = '{}.{}.{}'.format(*d['version_info'][:3])
node._workerinfocache = s = "[{}] {} -- Python {} {}".format( node._workerinfocache = s = '[{}] {} -- Python {} {}'.format(
d["id"], d["sysplatform"], ver, d["executable"] d['id'], d['sysplatform'], ver, d['executable'],
) )
return s return s
_R = TypeVar("_R", bound="BaseReport") _R = TypeVar('_R', bound='BaseReport')
class BaseReport: class BaseReport:
when: Optional[str] when: str | None
location: Optional[Tuple[str, Optional[int], str]] location: tuple[str, int | None, str] | None
longrepr: Union[ longrepr: (
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
] )
sections: List[Tuple[str, str]] sections: list[tuple[str, str]]
nodeid: str nodeid: str
outcome: Literal["passed", "failed", "skipped"] outcome: Literal['passed', 'failed', 'skipped']
def __init__(self, **kw: Any) -> None: def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw) self.__dict__.update(kw)
@ -76,7 +78,7 @@ class BaseReport:
... ...
def toterminal(self, out: TerminalWriter) -> None: def toterminal(self, out: TerminalWriter) -> None:
if hasattr(self, "node"): if hasattr(self, 'node'):
worker_info = getworkerinfoline(self.node) worker_info = getworkerinfoline(self.node)
if worker_info: if worker_info:
out.line(worker_info) out.line(worker_info)
@ -85,17 +87,17 @@ class BaseReport:
if longrepr is None: if longrepr is None:
return return
if hasattr(longrepr, "toterminal"): if hasattr(longrepr, 'toterminal'):
longrepr_terminal = cast(TerminalRepr, longrepr) longrepr_terminal = cast(TerminalRepr, longrepr)
longrepr_terminal.toterminal(out) longrepr_terminal.toterminal(out)
else: else:
try: try:
s = str(longrepr) s = str(longrepr)
except UnicodeEncodeError: except UnicodeEncodeError:
s = "<unprintable longrepr>" s = '<unprintable longrepr>'
out.line(s) out.line(s)
def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]: def get_sections(self, prefix: str) -> Iterator[tuple[str, str]]:
for name, content in self.sections: for name, content in self.sections:
if name.startswith(prefix): if name.startswith(prefix):
yield prefix, content yield prefix, content
@ -120,8 +122,8 @@ class BaseReport:
.. versionadded:: 3.5 .. versionadded:: 3.5
""" """
return "\n".join( return '\n'.join(
content for (prefix, content) in self.get_sections("Captured log") content for (prefix, content) in self.get_sections('Captured log')
) )
@property @property
@ -130,8 +132,8 @@ class BaseReport:
.. versionadded:: 3.0 .. versionadded:: 3.0
""" """
return "".join( return ''.join(
content for (prefix, content) in self.get_sections("Captured stdout") content for (prefix, content) in self.get_sections('Captured stdout')
) )
@property @property
@ -140,29 +142,29 @@ class BaseReport:
.. versionadded:: 3.0 .. versionadded:: 3.0
""" """
return "".join( return ''.join(
content for (prefix, content) in self.get_sections("Captured stderr") content for (prefix, content) in self.get_sections('Captured stderr')
) )
@property @property
def passed(self) -> bool: def passed(self) -> bool:
"""Whether the outcome is passed.""" """Whether the outcome is passed."""
return self.outcome == "passed" return self.outcome == 'passed'
@property @property
def failed(self) -> bool: def failed(self) -> bool:
"""Whether the outcome is failed.""" """Whether the outcome is failed."""
return self.outcome == "failed" return self.outcome == 'failed'
@property @property
def skipped(self) -> bool: def skipped(self) -> bool:
"""Whether the outcome is skipped.""" """Whether the outcome is skipped."""
return self.outcome == "skipped" return self.outcome == 'skipped'
@property @property
def fspath(self) -> str: def fspath(self) -> str:
"""The path portion of the reported node, as a string.""" """The path portion of the reported node, as a string."""
return self.nodeid.split("::")[0] return self.nodeid.split('::')[0]
@property @property
def count_towards_summary(self) -> bool: def count_towards_summary(self) -> bool:
@ -177,7 +179,7 @@ class BaseReport:
return True return True
@property @property
def head_line(self) -> Optional[str]: def head_line(self) -> str | None:
"""**Experimental** The head line shown with longrepr output for this """**Experimental** The head line shown with longrepr output for this
report, more commonly during traceback representation during report, more commonly during traceback representation during
failures:: failures::
@ -199,11 +201,11 @@ class BaseReport:
def _get_verbose_word(self, config: Config): def _get_verbose_word(self, config: Config):
_category, _short, verbose = config.hook.pytest_report_teststatus( _category, _short, verbose = config.hook.pytest_report_teststatus(
report=self, config=config report=self, config=config,
) )
return verbose return verbose
def _to_json(self) -> Dict[str, Any]: def _to_json(self) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries, """Return the contents of this report as a dict of builtin entries,
suitable for serialization. suitable for serialization.
@ -214,7 +216,7 @@ class BaseReport:
return _report_to_json(self) return _report_to_json(self)
@classmethod @classmethod
def _from_json(cls: Type[_R], reportdict: Dict[str, object]) -> _R: def _from_json(cls: type[_R], reportdict: dict[str, object]) -> _R:
"""Create either a TestReport or CollectReport, depending on the calling class. """Create either a TestReport or CollectReport, depending on the calling class.
It is the callers responsibility to know which class to pass here. It is the callers responsibility to know which class to pass here.
@ -228,16 +230,16 @@ class BaseReport:
def _report_unserialization_failure( def _report_unserialization_failure(
type_name: str, report_class: Type[BaseReport], reportdict type_name: str, report_class: type[BaseReport], reportdict,
) -> NoReturn: ) -> NoReturn:
url = "https://github.com/pytest-dev/pytest/issues" url = 'https://github.com/pytest-dev/pytest/issues'
stream = StringIO() stream = StringIO()
pprint("-" * 100, stream=stream) pprint('-' * 100, stream=stream)
pprint("INTERNALERROR: Unknown entry type returned: %s" % type_name, stream=stream) pprint('INTERNALERROR: Unknown entry type returned: %s' % type_name, stream=stream)
pprint("report_name: %s" % report_class, stream=stream) pprint('report_name: %s' % report_class, stream=stream)
pprint(reportdict, stream=stream) pprint(reportdict, stream=stream)
pprint("Please report this bug at %s" % url, stream=stream) pprint('Please report this bug at %s' % url, stream=stream)
pprint("-" * 100, stream=stream) pprint('-' * 100, stream=stream)
raise RuntimeError(stream.getvalue()) raise RuntimeError(stream.getvalue())
@ -257,18 +259,18 @@ class TestReport(BaseReport):
def __init__( def __init__(
self, self,
nodeid: str, nodeid: str,
location: Tuple[str, Optional[int], str], location: tuple[str, int | None, str],
keywords: Mapping[str, Any], keywords: Mapping[str, Any],
outcome: Literal["passed", "failed", "skipped"], outcome: Literal['passed', 'failed', 'skipped'],
longrepr: Union[ longrepr: (
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
], ),
when: Literal["setup", "call", "teardown"], when: Literal['setup', 'call', 'teardown'],
sections: Iterable[Tuple[str, str]] = (), sections: Iterable[tuple[str, str]] = (),
duration: float = 0, duration: float = 0,
start: float = 0, start: float = 0,
stop: float = 0, stop: float = 0,
user_properties: Optional[Iterable[Tuple[str, object]]] = None, user_properties: Iterable[tuple[str, object]] | None = None,
**extra, **extra,
) -> None: ) -> None:
#: Normalized collection nodeid. #: Normalized collection nodeid.
@ -279,7 +281,7 @@ class TestReport(BaseReport):
#: collected one e.g. if a method is inherited from a different module. #: collected one e.g. if a method is inherited from a different module.
#: The filesystempath may be relative to ``config.rootdir``. #: The filesystempath may be relative to ``config.rootdir``.
#: The line number is 0-based. #: The line number is 0-based.
self.location: Tuple[str, Optional[int], str] = location self.location: tuple[str, int | None, str] = location
#: A name -> value dictionary containing all keywords and #: A name -> value dictionary containing all keywords and
#: markers associated with a test invocation. #: markers associated with a test invocation.
@ -315,10 +317,10 @@ class TestReport(BaseReport):
self.__dict__.update(extra) self.__dict__.update(extra)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.nodeid!r} when={self.when!r} outcome={self.outcome!r}>" return f'<{self.__class__.__name__} {self.nodeid!r} when={self.when!r} outcome={self.outcome!r}>'
@classmethod @classmethod
def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport": def from_item_and_call(cls, item: Item, call: CallInfo[None]) -> TestReport:
"""Create and fill a TestReport with standard item and call info. """Create and fill a TestReport with standard item and call info.
:param item: The item. :param item: The item.
@ -326,7 +328,7 @@ class TestReport(BaseReport):
""" """
when = call.when when = call.when
# Remove "collect" from the Literal type -- only for collection calls. # Remove "collect" from the Literal type -- only for collection calls.
assert when != "collect" assert when != 'collect'
duration = call.duration duration = call.duration
start = call.start start = call.start
stop = call.stop stop = call.stop
@ -334,24 +336,24 @@ class TestReport(BaseReport):
excinfo = call.excinfo excinfo = call.excinfo
sections = [] sections = []
if not call.excinfo: if not call.excinfo:
outcome: Literal["passed", "failed", "skipped"] = "passed" outcome: Literal['passed', 'failed', 'skipped'] = 'passed'
longrepr: Union[ longrepr: (
None, None |
ExceptionInfo[BaseException], ExceptionInfo[BaseException] |
Tuple[str, int, str], tuple[str, int, str] |
str, str |
TerminalRepr, TerminalRepr
] = None ) = None
else: else:
if not isinstance(excinfo, ExceptionInfo): if not isinstance(excinfo, ExceptionInfo):
outcome = "failed" outcome = 'failed'
longrepr = excinfo longrepr = excinfo
elif isinstance(excinfo.value, skip.Exception): elif isinstance(excinfo.value, skip.Exception):
outcome = "skipped" outcome = 'skipped'
r = excinfo._getreprcrash() r = excinfo._getreprcrash()
assert ( assert (
r is not None r is not None
), "There should always be a traceback entry for skipping a test." ), 'There should always be a traceback entry for skipping a test.'
if excinfo.value._use_item_location: if excinfo.value._use_item_location:
path, line = item.reportinfo()[:2] path, line = item.reportinfo()[:2]
assert line is not None assert line is not None
@ -359,15 +361,15 @@ class TestReport(BaseReport):
else: else:
longrepr = (str(r.path), r.lineno, r.message) longrepr = (str(r.path), r.lineno, r.message)
else: else:
outcome = "failed" outcome = 'failed'
if call.when == "call": if call.when == 'call':
longrepr = item.repr_failure(excinfo) longrepr = item.repr_failure(excinfo)
else: # exception in setup or teardown else: # exception in setup or teardown
longrepr = item._repr_failure_py( longrepr = item._repr_failure_py(
excinfo, style=item.config.getoption("tbstyle", "auto") excinfo, style=item.config.getoption('tbstyle', 'auto'),
) )
for rwhen, key, content in item._report_sections: for rwhen, key, content in item._report_sections:
sections.append((f"Captured {key} {rwhen}", content)) sections.append((f'Captured {key} {rwhen}', content))
return cls( return cls(
item.nodeid, item.nodeid,
item.location, item.location,
@ -390,17 +392,17 @@ class CollectReport(BaseReport):
Reports can contain arbitrary extra attributes. Reports can contain arbitrary extra attributes.
""" """
when = "collect" when = 'collect'
def __init__( def __init__(
self, self,
nodeid: str, nodeid: str,
outcome: "Literal['passed', 'failed', 'skipped']", outcome: Literal['passed', 'failed', 'skipped'],
longrepr: Union[ longrepr: (
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
], ),
result: Optional[List[Union[Item, Collector]]], result: list[Item | Collector] | None,
sections: Iterable[Tuple[str, str]] = (), sections: Iterable[tuple[str, str]] = (),
**extra, **extra,
) -> None: ) -> None:
#: Normalized collection nodeid. #: Normalized collection nodeid.
@ -426,11 +428,11 @@ class CollectReport(BaseReport):
@property @property
def location( # type:ignore[override] def location( # type:ignore[override]
self, self,
) -> Optional[Tuple[str, Optional[int], str]]: ) -> tuple[str, int | None, str] | None:
return (self.fspath, None, self.fspath) return (self.fspath, None, self.fspath)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<CollectReport {self.nodeid!r} lenresult={len(self.result)} outcome={self.outcome!r}>" return f'<CollectReport {self.nodeid!r} lenresult={len(self.result)} outcome={self.outcome!r}>'
class CollectErrorRepr(TerminalRepr): class CollectErrorRepr(TerminalRepr):
@ -442,31 +444,31 @@ class CollectErrorRepr(TerminalRepr):
def pytest_report_to_serializable( def pytest_report_to_serializable(
report: Union[CollectReport, TestReport], report: CollectReport | TestReport,
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
if isinstance(report, (TestReport, CollectReport)): if isinstance(report, (TestReport, CollectReport)):
data = report._to_json() data = report._to_json()
data["$report_type"] = report.__class__.__name__ data['$report_type'] = report.__class__.__name__
return data return data
# TODO: Check if this is actually reachable. # TODO: Check if this is actually reachable.
return None # type: ignore[unreachable] return None # type: ignore[unreachable]
def pytest_report_from_serializable( def pytest_report_from_serializable(
data: Dict[str, Any], data: dict[str, Any],
) -> Optional[Union[CollectReport, TestReport]]: ) -> CollectReport | TestReport | None:
if "$report_type" in data: if '$report_type' in data:
if data["$report_type"] == "TestReport": if data['$report_type'] == 'TestReport':
return TestReport._from_json(data) return TestReport._from_json(data)
elif data["$report_type"] == "CollectReport": elif data['$report_type'] == 'CollectReport':
return CollectReport._from_json(data) return CollectReport._from_json(data)
assert False, "Unknown report_type unserialize data: {}".format( assert False, 'Unknown report_type unserialize data: {}'.format(
data["$report_type"] data['$report_type'],
) )
return None return None
def _report_to_json(report: BaseReport) -> Dict[str, Any]: def _report_to_json(report: BaseReport) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries, """Return the contents of this report as a dict of builtin entries,
suitable for serialization. suitable for serialization.
@ -474,72 +476,72 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
""" """
def serialize_repr_entry( def serialize_repr_entry(
entry: Union[ReprEntry, ReprEntryNative], entry: ReprEntry | ReprEntryNative,
) -> Dict[str, Any]: ) -> dict[str, Any]:
data = dataclasses.asdict(entry) data = dataclasses.asdict(entry)
for key, value in data.items(): for key, value in data.items():
if hasattr(value, "__dict__"): if hasattr(value, '__dict__'):
data[key] = dataclasses.asdict(value) data[key] = dataclasses.asdict(value)
entry_data = {"type": type(entry).__name__, "data": data} entry_data = {'type': type(entry).__name__, 'data': data}
return entry_data return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback) -> Dict[str, Any]: def serialize_repr_traceback(reprtraceback: ReprTraceback) -> dict[str, Any]:
result = dataclasses.asdict(reprtraceback) result = dataclasses.asdict(reprtraceback)
result["reprentries"] = [ result['reprentries'] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries serialize_repr_entry(x) for x in reprtraceback.reprentries
] ]
return result return result
def serialize_repr_crash( def serialize_repr_crash(
reprcrash: Optional[ReprFileLocation], reprcrash: ReprFileLocation | None,
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
if reprcrash is not None: if reprcrash is not None:
return dataclasses.asdict(reprcrash) return dataclasses.asdict(reprcrash)
else: else:
return None return None
def serialize_exception_longrepr(rep: BaseReport) -> Dict[str, Any]: def serialize_exception_longrepr(rep: BaseReport) -> dict[str, Any]:
assert rep.longrepr is not None assert rep.longrepr is not None
# TODO: Investigate whether the duck typing is really necessary here. # TODO: Investigate whether the duck typing is really necessary here.
longrepr = cast(ExceptionRepr, rep.longrepr) longrepr = cast(ExceptionRepr, rep.longrepr)
result: Dict[str, Any] = { result: dict[str, Any] = {
"reprcrash": serialize_repr_crash(longrepr.reprcrash), 'reprcrash': serialize_repr_crash(longrepr.reprcrash),
"reprtraceback": serialize_repr_traceback(longrepr.reprtraceback), 'reprtraceback': serialize_repr_traceback(longrepr.reprtraceback),
"sections": longrepr.sections, 'sections': longrepr.sections,
} }
if isinstance(longrepr, ExceptionChainRepr): if isinstance(longrepr, ExceptionChainRepr):
result["chain"] = [] result['chain'] = []
for repr_traceback, repr_crash, description in longrepr.chain: for repr_traceback, repr_crash, description in longrepr.chain:
result["chain"].append( result['chain'].append(
( (
serialize_repr_traceback(repr_traceback), serialize_repr_traceback(repr_traceback),
serialize_repr_crash(repr_crash), serialize_repr_crash(repr_crash),
description, description,
) ),
) )
else: else:
result["chain"] = None result['chain'] = None
return result return result
d = report.__dict__.copy() d = report.__dict__.copy()
if hasattr(report.longrepr, "toterminal"): if hasattr(report.longrepr, 'toterminal'):
if hasattr(report.longrepr, "reprtraceback") and hasattr( if hasattr(report.longrepr, 'reprtraceback') and hasattr(
report.longrepr, "reprcrash" report.longrepr, 'reprcrash',
): ):
d["longrepr"] = serialize_exception_longrepr(report) d['longrepr'] = serialize_exception_longrepr(report)
else: else:
d["longrepr"] = str(report.longrepr) d['longrepr'] = str(report.longrepr)
else: else:
d["longrepr"] = report.longrepr d['longrepr'] = report.longrepr
for name in d: for name in d:
if isinstance(d[name], os.PathLike): if isinstance(d[name], os.PathLike):
d[name] = os.fspath(d[name]) d[name] = os.fspath(d[name])
elif name == "result": elif name == 'result':
d[name] = None # for now d[name] = None # for now
return d return d
def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]: def _report_kwargs_from_json(reportdict: dict[str, Any]) -> dict[str, Any]:
"""Return **kwargs that can be used to construct a TestReport or """Return **kwargs that can be used to construct a TestReport or
CollectReport instance. CollectReport instance.
@ -547,76 +549,76 @@ def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
""" """
def deserialize_repr_entry(entry_data): def deserialize_repr_entry(entry_data):
data = entry_data["data"] data = entry_data['data']
entry_type = entry_data["type"] entry_type = entry_data['type']
if entry_type == "ReprEntry": if entry_type == 'ReprEntry':
reprfuncargs = None reprfuncargs = None
reprfileloc = None reprfileloc = None
reprlocals = None reprlocals = None
if data["reprfuncargs"]: if data['reprfuncargs']:
reprfuncargs = ReprFuncArgs(**data["reprfuncargs"]) reprfuncargs = ReprFuncArgs(**data['reprfuncargs'])
if data["reprfileloc"]: if data['reprfileloc']:
reprfileloc = ReprFileLocation(**data["reprfileloc"]) reprfileloc = ReprFileLocation(**data['reprfileloc'])
if data["reprlocals"]: if data['reprlocals']:
reprlocals = ReprLocals(data["reprlocals"]["lines"]) reprlocals = ReprLocals(data['reprlocals']['lines'])
reprentry: Union[ReprEntry, ReprEntryNative] = ReprEntry( reprentry: ReprEntry | ReprEntryNative = ReprEntry(
lines=data["lines"], lines=data['lines'],
reprfuncargs=reprfuncargs, reprfuncargs=reprfuncargs,
reprlocals=reprlocals, reprlocals=reprlocals,
reprfileloc=reprfileloc, reprfileloc=reprfileloc,
style=data["style"], style=data['style'],
) )
elif entry_type == "ReprEntryNative": elif entry_type == 'ReprEntryNative':
reprentry = ReprEntryNative(data["lines"]) reprentry = ReprEntryNative(data['lines'])
else: else:
_report_unserialization_failure(entry_type, TestReport, reportdict) _report_unserialization_failure(entry_type, TestReport, reportdict)
return reprentry return reprentry
def deserialize_repr_traceback(repr_traceback_dict): def deserialize_repr_traceback(repr_traceback_dict):
repr_traceback_dict["reprentries"] = [ repr_traceback_dict['reprentries'] = [
deserialize_repr_entry(x) for x in repr_traceback_dict["reprentries"] deserialize_repr_entry(x) for x in repr_traceback_dict['reprentries']
] ]
return ReprTraceback(**repr_traceback_dict) return ReprTraceback(**repr_traceback_dict)
def deserialize_repr_crash(repr_crash_dict: Optional[Dict[str, Any]]): def deserialize_repr_crash(repr_crash_dict: dict[str, Any] | None):
if repr_crash_dict is not None: if repr_crash_dict is not None:
return ReprFileLocation(**repr_crash_dict) return ReprFileLocation(**repr_crash_dict)
else: else:
return None return None
if ( if (
reportdict["longrepr"] reportdict['longrepr'] and
and "reprcrash" in reportdict["longrepr"] 'reprcrash' in reportdict['longrepr'] and
and "reprtraceback" in reportdict["longrepr"] 'reprtraceback' in reportdict['longrepr']
): ):
reprtraceback = deserialize_repr_traceback( reprtraceback = deserialize_repr_traceback(
reportdict["longrepr"]["reprtraceback"] reportdict['longrepr']['reprtraceback'],
) )
reprcrash = deserialize_repr_crash(reportdict["longrepr"]["reprcrash"]) reprcrash = deserialize_repr_crash(reportdict['longrepr']['reprcrash'])
if reportdict["longrepr"]["chain"]: if reportdict['longrepr']['chain']:
chain = [] chain = []
for repr_traceback_data, repr_crash_data, description in reportdict[ for repr_traceback_data, repr_crash_data, description in reportdict[
"longrepr" 'longrepr'
]["chain"]: ]['chain']:
chain.append( chain.append(
( (
deserialize_repr_traceback(repr_traceback_data), deserialize_repr_traceback(repr_traceback_data),
deserialize_repr_crash(repr_crash_data), deserialize_repr_crash(repr_crash_data),
description, description,
) ),
) )
exception_info: Union[ exception_info: (
ExceptionChainRepr, ReprExceptionInfo ExceptionChainRepr | ReprExceptionInfo
] = ExceptionChainRepr(chain) ) = ExceptionChainRepr(chain)
else: else:
exception_info = ReprExceptionInfo( exception_info = ReprExceptionInfo(
reprtraceback=reprtraceback, reprtraceback=reprtraceback,
reprcrash=reprcrash, reprcrash=reprcrash,
) )
for section in reportdict["longrepr"]["sections"]: for section in reportdict['longrepr']['sections']:
exception_info.addsection(*section) exception_info.addsection(*section)
reportdict["longrepr"] = exception_info reportdict['longrepr'] = exception_info
return reportdict return reportdict

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Basic collect and runtest protocol implementations.""" """Basic collect and runtest protocol implementations."""
from __future__ import annotations
import bdb import bdb
import dataclasses import dataclasses
import os import os
@ -18,10 +20,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
from _pytest import timing from _pytest import timing
from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo from _pytest._code.code import ExceptionInfo
@ -37,6 +35,11 @@ from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME from _pytest.outcomes import TEST_OUTCOME
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
if sys.version_info[:2] < (3, 11): if sys.version_info[:2] < (3, 11):
from exceptiongroup import BaseExceptionGroup from exceptiongroup import BaseExceptionGroup
@ -50,66 +53,66 @@ if TYPE_CHECKING:
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "Reporting", after="general") group = parser.getgroup('terminal reporting', 'Reporting', after='general')
group.addoption( group.addoption(
"--durations", '--durations',
action="store", action='store',
type=int, type=int,
default=None, default=None,
metavar="N", metavar='N',
help="Show N slowest setup/test durations (N=0 for all)", help='Show N slowest setup/test durations (N=0 for all)',
) )
group.addoption( group.addoption(
"--durations-min", '--durations-min',
action="store", action='store',
type=float, type=float,
default=0.005, default=0.005,
metavar="N", metavar='N',
help="Minimal duration in seconds for inclusion in slowest list. " help='Minimal duration in seconds for inclusion in slowest list. '
"Default: 0.005.", 'Default: 0.005.',
) )
def pytest_terminal_summary(terminalreporter: "TerminalReporter") -> None: def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
durations = terminalreporter.config.option.durations durations = terminalreporter.config.option.durations
durations_min = terminalreporter.config.option.durations_min durations_min = terminalreporter.config.option.durations_min
verbose = terminalreporter.config.getvalue("verbose") verbose = terminalreporter.config.getvalue('verbose')
if durations is None: if durations is None:
return return
tr = terminalreporter tr = terminalreporter
dlist = [] dlist = []
for replist in tr.stats.values(): for replist in tr.stats.values():
for rep in replist: for rep in replist:
if hasattr(rep, "duration"): if hasattr(rep, 'duration'):
dlist.append(rep) dlist.append(rep)
if not dlist: if not dlist:
return return
dlist.sort(key=lambda x: x.duration, reverse=True) # type: ignore[no-any-return] dlist.sort(key=lambda x: x.duration, reverse=True) # type: ignore[no-any-return]
if not durations: if not durations:
tr.write_sep("=", "slowest durations") tr.write_sep('=', 'slowest durations')
else: else:
tr.write_sep("=", "slowest %s durations" % durations) tr.write_sep('=', 'slowest %s durations' % durations)
dlist = dlist[:durations] dlist = dlist[:durations]
for i, rep in enumerate(dlist): for i, rep in enumerate(dlist):
if verbose < 2 and rep.duration < durations_min: if verbose < 2 and rep.duration < durations_min:
tr.write_line("") tr.write_line('')
tr.write_line( tr.write_line(
f"({len(dlist) - i} durations < {durations_min:g}s hidden. Use -vv to show these durations.)" f'({len(dlist) - i} durations < {durations_min:g}s hidden. Use -vv to show these durations.)',
) )
break break
tr.write_line(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}") tr.write_line(f'{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}')
def pytest_sessionstart(session: "Session") -> None: def pytest_sessionstart(session: Session) -> None:
session._setupstate = SetupState() session._setupstate = SetupState()
def pytest_sessionfinish(session: "Session") -> None: def pytest_sessionfinish(session: Session) -> None:
session._setupstate.teardown_exact(None) session._setupstate.teardown_exact(None)
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool: def pytest_runtest_protocol(item: Item, nextitem: Item | None) -> bool:
ihook = item.ihook ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location) ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
runtestprotocol(item, nextitem=nextitem) runtestprotocol(item, nextitem=nextitem)
@ -118,21 +121,21 @@ def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
def runtestprotocol( def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None item: Item, log: bool = True, nextitem: Item | None = None,
) -> List[TestReport]: ) -> list[TestReport]:
hasrequest = hasattr(item, "_request") hasrequest = hasattr(item, '_request')
if hasrequest and not item._request: # type: ignore[attr-defined] if hasrequest and not item._request: # type: ignore[attr-defined]
# This only happens if the item is re-run, as is done by # This only happens if the item is re-run, as is done by
# pytest-rerunfailures. # pytest-rerunfailures.
item._initrequest() # type: ignore[attr-defined] item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log) rep = call_and_report(item, 'setup', log)
reports = [rep] reports = [rep]
if rep.passed: if rep.passed:
if item.config.getoption("setupshow", False): if item.config.getoption('setupshow', False):
show_test_item(item) show_test_item(item)
if not item.config.getoption("setuponly", False): if not item.config.getoption('setuponly', False):
reports.append(call_and_report(item, "call", log)) reports.append(call_and_report(item, 'call', log))
reports.append(call_and_report(item, "teardown", log, nextitem=nextitem)) reports.append(call_and_report(item, 'teardown', log, nextitem=nextitem))
# After all teardown hooks have been called # After all teardown hooks have been called
# want funcargs and request info to go away. # want funcargs and request info to go away.
if hasrequest: if hasrequest:
@ -145,21 +148,21 @@ def show_test_item(item: Item) -> None:
"""Show test function, parameters and the fixtures of the test item.""" """Show test function, parameters and the fixtures of the test item."""
tw = item.config.get_terminal_writer() tw = item.config.get_terminal_writer()
tw.line() tw.line()
tw.write(" " * 8) tw.write(' ' * 8)
tw.write(item.nodeid) tw.write(item.nodeid)
used_fixtures = sorted(getattr(item, "fixturenames", [])) used_fixtures = sorted(getattr(item, 'fixturenames', []))
if used_fixtures: if used_fixtures:
tw.write(" (fixtures used: {})".format(", ".join(used_fixtures))) tw.write(' (fixtures used: {})'.format(', '.join(used_fixtures)))
tw.flush() tw.flush()
def pytest_runtest_setup(item: Item) -> None: def pytest_runtest_setup(item: Item) -> None:
_update_current_test_var(item, "setup") _update_current_test_var(item, 'setup')
item.session._setupstate.setup(item) item.session._setupstate.setup(item)
def pytest_runtest_call(item: Item) -> None: def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call") _update_current_test_var(item, 'call')
try: try:
del sys.last_type del sys.last_type
del sys.last_value del sys.last_value
@ -182,38 +185,38 @@ def pytest_runtest_call(item: Item) -> None:
raise e raise e
def pytest_runtest_teardown(item: Item, nextitem: Optional[Item]) -> None: def pytest_runtest_teardown(item: Item, nextitem: Item | None) -> None:
_update_current_test_var(item, "teardown") _update_current_test_var(item, 'teardown')
item.session._setupstate.teardown_exact(nextitem) item.session._setupstate.teardown_exact(nextitem)
_update_current_test_var(item, None) _update_current_test_var(item, None)
def _update_current_test_var( def _update_current_test_var(
item: Item, when: Optional[Literal["setup", "call", "teardown"]] item: Item, when: Literal['setup', 'call', 'teardown'] | None,
) -> None: ) -> None:
"""Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage. """Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage.
If ``when`` is None, delete ``PYTEST_CURRENT_TEST`` from the environment. If ``when`` is None, delete ``PYTEST_CURRENT_TEST`` from the environment.
""" """
var_name = "PYTEST_CURRENT_TEST" var_name = 'PYTEST_CURRENT_TEST'
if when: if when:
value = f"{item.nodeid} ({when})" value = f'{item.nodeid} ({when})'
# don't allow null bytes on environment variables (see #2644, #2957) # don't allow null bytes on environment variables (see #2644, #2957)
value = value.replace("\x00", "(null)") value = value.replace('\x00', '(null)')
os.environ[var_name] = value os.environ[var_name] = value
else: else:
os.environ.pop(var_name) os.environ.pop(var_name)
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]: def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if report.when in ("setup", "teardown"): if report.when in ('setup', 'teardown'):
if report.failed: if report.failed:
# category, shortletter, verbose-word # category, shortletter, verbose-word
return "error", "E", "ERROR" return 'error', 'E', 'ERROR'
elif report.skipped: elif report.skipped:
return "skipped", "s", "SKIPPED" return 'skipped', 's', 'SKIPPED'
else: else:
return "", "", "" return '', '', ''
return None return None
@ -222,22 +225,22 @@ def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str
def call_and_report( def call_and_report(
item: Item, when: Literal["setup", "call", "teardown"], log: bool = True, **kwds item: Item, when: Literal['setup', 'call', 'teardown'], log: bool = True, **kwds,
) -> TestReport: ) -> TestReport:
ihook = item.ihook ihook = item.ihook
if when == "setup": if when == 'setup':
runtest_hook: Callable[..., None] = ihook.pytest_runtest_setup runtest_hook: Callable[..., None] = ihook.pytest_runtest_setup
elif when == "call": elif when == 'call':
runtest_hook = ihook.pytest_runtest_call runtest_hook = ihook.pytest_runtest_call
elif when == "teardown": elif when == 'teardown':
runtest_hook = ihook.pytest_runtest_teardown runtest_hook = ihook.pytest_runtest_teardown
else: else:
assert False, f"Unhandled runtest hook case: {when}" assert False, f'Unhandled runtest hook case: {when}'
reraise: Tuple[Type[BaseException], ...] = (Exit,) reraise: tuple[type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False): if not item.config.getoption('usepdb', False):
reraise += (KeyboardInterrupt,) reraise += (KeyboardInterrupt,)
call = CallInfo.from_call( call = CallInfo.from_call(
lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise,
) )
report: TestReport = ihook.pytest_runtest_makereport(item=item, call=call) report: TestReport = ihook.pytest_runtest_makereport(item=item, call=call)
if log: if log:
@ -247,13 +250,13 @@ def call_and_report(
return report return report
def check_interactive_exception(call: "CallInfo[object]", report: BaseReport) -> bool: def check_interactive_exception(call: CallInfo[object], report: BaseReport) -> bool:
"""Check whether the call raised an exception that should be reported as """Check whether the call raised an exception that should be reported as
interactive.""" interactive."""
if call.excinfo is None: if call.excinfo is None:
# Didn't raise. # Didn't raise.
return False return False
if hasattr(report, "wasxfail"): if hasattr(report, 'wasxfail'):
# Exception was expected. # Exception was expected.
return False return False
if isinstance(call.excinfo.value, (Skipped, bdb.BdbQuit)): if isinstance(call.excinfo.value, (Skipped, bdb.BdbQuit)):
@ -262,7 +265,7 @@ def check_interactive_exception(call: "CallInfo[object]", report: BaseReport) ->
return True return True
TResult = TypeVar("TResult", covariant=True) TResult = TypeVar('TResult', covariant=True)
@final @final
@ -270,9 +273,9 @@ TResult = TypeVar("TResult", covariant=True)
class CallInfo(Generic[TResult]): class CallInfo(Generic[TResult]):
"""Result/Exception info of a function invocation.""" """Result/Exception info of a function invocation."""
_result: Optional[TResult] _result: TResult | None
#: The captured exception of the call, if it raised. #: The captured exception of the call, if it raised.
excinfo: Optional[ExceptionInfo[BaseException]] excinfo: ExceptionInfo[BaseException] | None
#: The system time when the call started, in seconds since the epoch. #: The system time when the call started, in seconds since the epoch.
start: float start: float
#: The system time when the call ended, in seconds since the epoch. #: The system time when the call ended, in seconds since the epoch.
@ -280,16 +283,16 @@ class CallInfo(Generic[TResult]):
#: The call duration, in seconds. #: The call duration, in seconds.
duration: float duration: float
#: The context of invocation: "collect", "setup", "call" or "teardown". #: The context of invocation: "collect", "setup", "call" or "teardown".
when: Literal["collect", "setup", "call", "teardown"] when: Literal['collect', 'setup', 'call', 'teardown']
def __init__( def __init__(
self, self,
result: Optional[TResult], result: TResult | None,
excinfo: Optional[ExceptionInfo[BaseException]], excinfo: ExceptionInfo[BaseException] | None,
start: float, start: float,
stop: float, stop: float,
duration: float, duration: float,
when: Literal["collect", "setup", "call", "teardown"], when: Literal['collect', 'setup', 'call', 'teardown'],
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -308,7 +311,7 @@ class CallInfo(Generic[TResult]):
Can only be accessed if excinfo is None. Can only be accessed if excinfo is None.
""" """
if self.excinfo is not None: if self.excinfo is not None:
raise AttributeError(f"{self!r} has no valid result") raise AttributeError(f'{self!r} has no valid result')
# The cast is safe because an exception wasn't raised, hence # The cast is safe because an exception wasn't raised, hence
# _result has the expected function return type (which may be # _result has the expected function return type (which may be
# None, that's why a cast and not an assert). # None, that's why a cast and not an assert).
@ -318,11 +321,11 @@ class CallInfo(Generic[TResult]):
def from_call( def from_call(
cls, cls,
func: Callable[[], TResult], func: Callable[[], TResult],
when: Literal["collect", "setup", "call", "teardown"], when: Literal['collect', 'setup', 'call', 'teardown'],
reraise: Optional[ reraise: None | (
Union[Type[BaseException], Tuple[Type[BaseException], ...]] type[BaseException] | tuple[type[BaseException], ...]
] = None, ) = None,
) -> "CallInfo[TResult]": ) -> CallInfo[TResult]:
"""Call func, wrapping the result in a CallInfo. """Call func, wrapping the result in a CallInfo.
:param func: :param func:
@ -337,7 +340,7 @@ class CallInfo(Generic[TResult]):
start = timing.time() start = timing.time()
precise_start = timing.perf_counter() precise_start = timing.perf_counter()
try: try:
result: Optional[TResult] = func() result: TResult | None = func()
except BaseException: except BaseException:
excinfo = ExceptionInfo.from_current() excinfo = ExceptionInfo.from_current()
if reraise is not None and isinstance(excinfo.value, reraise): if reraise is not None and isinstance(excinfo.value, reraise):
@ -359,8 +362,8 @@ class CallInfo(Generic[TResult]):
def __repr__(self) -> str: def __repr__(self) -> str:
if self.excinfo is None: if self.excinfo is None:
return f"<CallInfo when={self.when!r} result: {self._result!r}>" return f'<CallInfo when={self.when!r} result: {self._result!r}>'
return f"<CallInfo when={self.when!r} excinfo={self.excinfo!r}>" return f'<CallInfo when={self.when!r} excinfo={self.excinfo!r}>'
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport: def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
@ -368,7 +371,7 @@ def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
def pytest_make_collect_report(collector: Collector) -> CollectReport: def pytest_make_collect_report(collector: Collector) -> CollectReport:
def collect() -> List[Union[Item, Collector]]: def collect() -> list[Item | Collector]:
# Before collecting, if this is a Directory, load the conftests. # Before collecting, if this is a Directory, load the conftests.
# If a conftest import fails to load, it is considered a collection # If a conftest import fails to load, it is considered a collection
# error of the Directory collector. This is why it's done inside of the # error of the Directory collector. This is why it's done inside of the
@ -378,36 +381,36 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport:
if isinstance(collector, Directory): if isinstance(collector, Directory):
collector.config.pluginmanager._loadconftestmodules( collector.config.pluginmanager._loadconftestmodules(
collector.path, collector.path,
collector.config.getoption("importmode"), collector.config.getoption('importmode'),
rootpath=collector.config.rootpath, rootpath=collector.config.rootpath,
consider_namespace_packages=collector.config.getini( consider_namespace_packages=collector.config.getini(
"consider_namespace_packages" 'consider_namespace_packages',
), ),
) )
return list(collector.collect()) return list(collector.collect())
call = CallInfo.from_call(collect, "collect") call = CallInfo.from_call(collect, 'collect')
longrepr: Union[None, Tuple[str, int, str], str, TerminalRepr] = None longrepr: None | tuple[str, int, str] | str | TerminalRepr = None
if not call.excinfo: if not call.excinfo:
outcome: Literal["passed", "skipped", "failed"] = "passed" outcome: Literal['passed', 'skipped', 'failed'] = 'passed'
else: else:
skip_exceptions = [Skipped] skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest") unittest = sys.modules.get('unittest')
if unittest is not None: if unittest is not None:
# Type ignored because unittest is loaded dynamically. # Type ignored because unittest is loaded dynamically.
skip_exceptions.append(unittest.SkipTest) # type: ignore skip_exceptions.append(unittest.SkipTest) # type: ignore
if isinstance(call.excinfo.value, tuple(skip_exceptions)): if isinstance(call.excinfo.value, tuple(skip_exceptions)):
outcome = "skipped" outcome = 'skipped'
r_ = collector._repr_failure_py(call.excinfo, "line") r_ = collector._repr_failure_py(call.excinfo, 'line')
assert isinstance(r_, ExceptionChainRepr), repr(r_) assert isinstance(r_, ExceptionChainRepr), repr(r_)
r = r_.reprcrash r = r_.reprcrash
assert r assert r
longrepr = (str(r.path), r.lineno, r.message) longrepr = (str(r.path), r.lineno, r.message)
else: else:
outcome = "failed" outcome = 'failed'
errorinfo = collector.repr_failure(call.excinfo) errorinfo = collector.repr_failure(call.excinfo)
if not hasattr(errorinfo, "toterminal"): if not hasattr(errorinfo, 'toterminal'):
assert isinstance(errorinfo, str) assert isinstance(errorinfo, str)
errorinfo = CollectErrorRepr(errorinfo) errorinfo = CollectErrorRepr(errorinfo)
longrepr = errorinfo longrepr = errorinfo
@ -483,13 +486,13 @@ class SetupState:
def __init__(self) -> None: def __init__(self) -> None:
# The stack is in the dict insertion order. # The stack is in the dict insertion order.
self.stack: Dict[ self.stack: dict[
Node, Node,
Tuple[ tuple[
# Node's finalizers. # Node's finalizers.
List[Callable[[], object]], list[Callable[[], object]],
# Node's exception, if its setup raised. # Node's exception, if its setup raised.
Optional[Union[OutcomeException, Exception]], OutcomeException | Exception | None,
], ],
] = {} ] = {}
@ -500,11 +503,11 @@ class SetupState:
# If a collector fails its setup, fail its entire subtree of items. # If a collector fails its setup, fail its entire subtree of items.
# The setup is not retried for each item - the same exception is used. # The setup is not retried for each item - the same exception is used.
for col, (finalizers, exc) in self.stack.items(): for col, (finalizers, exc) in self.stack.items():
assert col in needed_collectors, "previous item was not torn down properly" assert col in needed_collectors, 'previous item was not torn down properly'
if exc: if exc:
raise exc raise exc
for col in needed_collectors[len(self.stack) :]: for col in needed_collectors[len(self.stack):]:
assert col not in self.stack assert col not in self.stack
# Push onto the stack. # Push onto the stack.
self.stack[col] = ([col.teardown], None) self.stack[col] = ([col.teardown], None)
@ -524,7 +527,7 @@ class SetupState:
assert node in self.stack, (node, self.stack) assert node in self.stack, (node, self.stack)
self.stack[node][0].append(finalizer) self.stack[node][0].append(finalizer)
def teardown_exact(self, nextitem: Optional[Item]) -> None: def teardown_exact(self, nextitem: Item | None) -> None:
"""Teardown the current stack up until reaching nodes that nextitem """Teardown the current stack up until reaching nodes that nextitem
also descends from. also descends from.
@ -532,7 +535,7 @@ class SetupState:
stack is torn down. stack is torn down.
""" """
needed_collectors = nextitem and nextitem.listchain() or [] needed_collectors = nextitem and nextitem.listchain() or []
exceptions: List[BaseException] = [] exceptions: list[BaseException] = []
while self.stack: while self.stack:
if list(self.stack.keys()) == needed_collectors[: len(self.stack)]: if list(self.stack.keys()) == needed_collectors[: len(self.stack)]:
break break
@ -548,13 +551,13 @@ class SetupState:
if len(these_exceptions) == 1: if len(these_exceptions) == 1:
exceptions.extend(these_exceptions) exceptions.extend(these_exceptions)
elif these_exceptions: elif these_exceptions:
msg = f"errors while tearing down {node!r}" msg = f'errors while tearing down {node!r}'
exceptions.append(BaseExceptionGroup(msg, these_exceptions[::-1])) exceptions.append(BaseExceptionGroup(msg, these_exceptions[::-1]))
if len(exceptions) == 1: if len(exceptions) == 1:
raise exceptions[0] raise exceptions[0]
elif exceptions: elif exceptions:
raise BaseExceptionGroup("errors during test teardown", exceptions[::-1]) raise BaseExceptionGroup('errors during test teardown', exceptions[::-1])
if nextitem is None: if nextitem is None:
assert not self.stack assert not self.stack
@ -563,7 +566,7 @@ def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook ihook = collector.ihook
ihook.pytest_collectstart(collector=collector) ihook.pytest_collectstart(collector=collector)
rep: CollectReport = ihook.pytest_make_collect_report(collector=collector) rep: CollectReport = ihook.pytest_make_collect_report(collector=collector)
call = rep.__dict__.pop("call", None) call = rep.__dict__.pop('call', None)
if call and check_interactive_exception(call, rep): if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep) ihook.pytest_exception_interact(node=collector, call=call, report=rep)
return rep return rep

View file

@ -7,6 +7,7 @@ would cause circular references.
Also this makes the module light to import, as it should. Also this makes the module light to import, as it should.
""" """
from __future__ import annotations
from enum import Enum from enum import Enum
from functools import total_ordering from functools import total_ordering
@ -14,7 +15,7 @@ from typing import Literal
from typing import Optional from typing import Optional
_ScopeName = Literal["session", "package", "module", "class", "function"] _ScopeName = Literal['session', 'package', 'module', 'class', 'function']
@total_ordering @total_ordering
@ -32,35 +33,35 @@ class Scope(Enum):
""" """
# Scopes need to be listed from lower to higher. # Scopes need to be listed from lower to higher.
Function: _ScopeName = "function" Function: _ScopeName = 'function'
Class: _ScopeName = "class" Class: _ScopeName = 'class'
Module: _ScopeName = "module" Module: _ScopeName = 'module'
Package: _ScopeName = "package" Package: _ScopeName = 'package'
Session: _ScopeName = "session" Session: _ScopeName = 'session'
def next_lower(self) -> "Scope": def next_lower(self) -> Scope:
"""Return the next lower scope.""" """Return the next lower scope."""
index = _SCOPE_INDICES[self] index = _SCOPE_INDICES[self]
if index == 0: if index == 0:
raise ValueError(f"{self} is the lower-most scope") raise ValueError(f'{self} is the lower-most scope')
return _ALL_SCOPES[index - 1] return _ALL_SCOPES[index - 1]
def next_higher(self) -> "Scope": def next_higher(self) -> Scope:
"""Return the next higher scope.""" """Return the next higher scope."""
index = _SCOPE_INDICES[self] index = _SCOPE_INDICES[self]
if index == len(_SCOPE_INDICES) - 1: if index == len(_SCOPE_INDICES) - 1:
raise ValueError(f"{self} is the upper-most scope") raise ValueError(f'{self} is the upper-most scope')
return _ALL_SCOPES[index + 1] return _ALL_SCOPES[index + 1]
def __lt__(self, other: "Scope") -> bool: def __lt__(self, other: Scope) -> bool:
self_index = _SCOPE_INDICES[self] self_index = _SCOPE_INDICES[self]
other_index = _SCOPE_INDICES[other] other_index = _SCOPE_INDICES[other]
return self_index < other_index return self_index < other_index
@classmethod @classmethod
def from_user( def from_user(
cls, scope_name: _ScopeName, descr: str, where: Optional[str] = None cls, scope_name: _ScopeName, descr: str, where: str | None = None,
) -> "Scope": ) -> Scope:
""" """
Given a scope name from the user, return the equivalent Scope enum. Should be used Given a scope name from the user, return the equivalent Scope enum. Should be used
whenever we want to convert a user provided scope name to its enum object. whenever we want to convert a user provided scope name to its enum object.
@ -75,7 +76,7 @@ class Scope(Enum):
except ValueError: except ValueError:
fail( fail(
"{} {}got an unexpected scope value '{}'".format( "{} {}got an unexpected scope value '{}'".format(
descr, f"from {where} " if where else "", scope_name descr, f'from {where} ' if where else '', scope_name,
), ),
pytrace=False, pytrace=False,
) )

View file

@ -1,7 +1,10 @@
from __future__ import annotations
from typing import Generator from typing import Generator
from typing import Optional from typing import Optional
from typing import Union from typing import Union
import pytest
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
@ -9,34 +12,33 @@ from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest from _pytest.fixtures import SubRequest
from _pytest.scope import Scope from _pytest.scope import Scope
import pytest
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig") group = parser.getgroup('debugconfig')
group.addoption( group.addoption(
"--setuponly", '--setuponly',
"--setup-only", '--setup-only',
action="store_true", action='store_true',
help="Only setup fixtures, do not execute tests", help='Only setup fixtures, do not execute tests',
) )
group.addoption( group.addoption(
"--setupshow", '--setupshow',
"--setup-show", '--setup-show',
action="store_true", action='store_true',
help="Show setup of fixtures while executing tests", help='Show setup of fixtures while executing tests',
) )
@pytest.hookimpl(wrapper=True) @pytest.hookimpl(wrapper=True)
def pytest_fixture_setup( def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest fixturedef: FixtureDef[object], request: SubRequest,
) -> Generator[None, object, object]: ) -> Generator[None, object, object]:
try: try:
return (yield) return (yield)
finally: finally:
if request.config.option.setupshow: if request.config.option.setupshow:
if hasattr(request, "param"): if hasattr(request, 'param'):
# Save the fixture parameter so ._show_fixture_action() can # Save the fixture parameter so ._show_fixture_action() can
# display it now and during the teardown (in .finish()). # display it now and during the teardown (in .finish()).
if fixturedef.ids: if fixturedef.ids:
@ -47,24 +49,24 @@ def pytest_fixture_setup(
else: else:
param = request.param param = request.param
fixturedef.cached_param = param # type: ignore[attr-defined] fixturedef.cached_param = param # type: ignore[attr-defined]
_show_fixture_action(fixturedef, request.config, "SETUP") _show_fixture_action(fixturedef, request.config, 'SETUP')
def pytest_fixture_post_finalizer( def pytest_fixture_post_finalizer(
fixturedef: FixtureDef[object], request: SubRequest fixturedef: FixtureDef[object], request: SubRequest,
) -> None: ) -> None:
if fixturedef.cached_result is not None: if fixturedef.cached_result is not None:
config = request.config config = request.config
if config.option.setupshow: if config.option.setupshow:
_show_fixture_action(fixturedef, request.config, "TEARDOWN") _show_fixture_action(fixturedef, request.config, 'TEARDOWN')
if hasattr(fixturedef, "cached_param"): if hasattr(fixturedef, 'cached_param'):
del fixturedef.cached_param # type: ignore[attr-defined] del fixturedef.cached_param # type: ignore[attr-defined]
def _show_fixture_action( def _show_fixture_action(
fixturedef: FixtureDef[object], config: Config, msg: str fixturedef: FixtureDef[object], config: Config, msg: str,
) -> None: ) -> None:
capman = config.pluginmanager.getplugin("capturemanager") capman = config.pluginmanager.getplugin('capturemanager')
if capman: if capman:
capman.suspend_global_capture() capman.suspend_global_capture()
@ -72,22 +74,22 @@ def _show_fixture_action(
tw.line() tw.line()
# Use smaller indentation the higher the scope: Session = 0, Package = 1, etc. # Use smaller indentation the higher the scope: Session = 0, Package = 1, etc.
scope_indent = list(reversed(Scope)).index(fixturedef._scope) scope_indent = list(reversed(Scope)).index(fixturedef._scope)
tw.write(" " * 2 * scope_indent) tw.write(' ' * 2 * scope_indent)
tw.write( tw.write(
"{step} {scope} {fixture}".format( # noqa: UP032 (Readability) '{step} {scope} {fixture}'.format( # noqa: UP032 (Readability)
step=msg.ljust(8), # align the output to TEARDOWN step=msg.ljust(8), # align the output to TEARDOWN
scope=fixturedef.scope[0].upper(), scope=fixturedef.scope[0].upper(),
fixture=fixturedef.argname, fixture=fixturedef.argname,
) ),
) )
if msg == "SETUP": if msg == 'SETUP':
deps = sorted(arg for arg in fixturedef.argnames if arg != "request") deps = sorted(arg for arg in fixturedef.argnames if arg != 'request')
if deps: if deps:
tw.write(" (fixtures used: {})".format(", ".join(deps))) tw.write(' (fixtures used: {})'.format(', '.join(deps)))
if hasattr(fixturedef, "cached_param"): if hasattr(fixturedef, 'cached_param'):
tw.write(f"[{saferepr(fixturedef.cached_param, maxsize=42)}]") # type: ignore[attr-defined] tw.write(f'[{saferepr(fixturedef.cached_param, maxsize=42)}]') # type: ignore[attr-defined]
tw.flush() tw.flush()
@ -96,7 +98,7 @@ def _show_fixture_action(
@pytest.hookimpl(tryfirst=True) @pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setuponly: if config.option.setuponly:
config.option.setupshow = True config.option.setupshow = True
return None return None

View file

@ -1,29 +1,31 @@
from __future__ import annotations
from typing import Optional from typing import Optional
from typing import Union from typing import Union
import pytest
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest from _pytest.fixtures import SubRequest
import pytest
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig") group = parser.getgroup('debugconfig')
group.addoption( group.addoption(
"--setupplan", '--setupplan',
"--setup-plan", '--setup-plan',
action="store_true", action='store_true',
help="Show what fixtures and tests would be executed but " help='Show what fixtures and tests would be executed but '
"don't execute anything", "don't execute anything",
) )
@pytest.hookimpl(tryfirst=True) @pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup( def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest fixturedef: FixtureDef[object], request: SubRequest,
) -> Optional[object]: ) -> object | None:
# Will return a dummy fixture if the setuponly option is provided. # Will return a dummy fixture if the setuponly option is provided.
if request.config.option.setupplan: if request.config.option.setupplan:
my_cache_key = fixturedef.cache_key(request) my_cache_key = fixturedef.cache_key(request)
@ -33,7 +35,7 @@ def pytest_fixture_setup(
@pytest.hookimpl(tryfirst=True) @pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setupplan: if config.option.setupplan:
config.option.setuponly = True config.option.setuponly = True
config.option.setupshow = True config.option.setupshow = True

View file

@ -1,11 +1,13 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Support for skip/xfail functions and markers.""" """Support for skip/xfail functions and markers."""
from collections.abc import Mapping from __future__ import annotations
import dataclasses import dataclasses
import os import os
import platform import platform
import sys import sys
import traceback import traceback
from collections.abc import Mapping
from typing import Generator from typing import Generator
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
@ -26,21 +28,21 @@ from _pytest.stash import StashKey
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general") group = parser.getgroup('general')
group.addoption( group.addoption(
"--runxfail", '--runxfail',
action="store_true", action='store_true',
dest="runxfail", dest='runxfail',
default=False, default=False,
help="Report the results of xfail tests as if they were not marked", help='Report the results of xfail tests as if they were not marked',
) )
parser.addini( parser.addini(
"xfail_strict", 'xfail_strict',
"Default for the strict parameter of xfail " 'Default for the strict parameter of xfail '
"markers when not given explicitly (default: False)", 'markers when not given explicitly (default: False)',
default=False, default=False,
type="bool", type='bool',
) )
@ -50,40 +52,40 @@ def pytest_configure(config: Config) -> None:
import pytest import pytest
old = pytest.xfail old = pytest.xfail
config.add_cleanup(lambda: setattr(pytest, "xfail", old)) config.add_cleanup(lambda: setattr(pytest, 'xfail', old))
def nop(*args, **kwargs): def nop(*args, **kwargs):
pass pass
nop.Exception = xfail.Exception # type: ignore[attr-defined] nop.Exception = xfail.Exception # type: ignore[attr-defined]
setattr(pytest, "xfail", nop) setattr(pytest, 'xfail', nop)
config.addinivalue_line( config.addinivalue_line(
"markers", 'markers',
"skip(reason=None): skip the given test function with an optional reason. " 'skip(reason=None): skip the given test function with an optional reason. '
'Example: skip(reason="no way of currently testing this") skips the ' 'Example: skip(reason="no way of currently testing this") skips the '
"test.", 'test.',
) )
config.addinivalue_line( config.addinivalue_line(
"markers", 'markers',
"skipif(condition, ..., *, reason=...): " 'skipif(condition, ..., *, reason=...): '
"skip the given test function if any of the conditions evaluate to True. " 'skip the given test function if any of the conditions evaluate to True. '
"Example: skipif(sys.platform == 'win32') skips the test if we are on the win32 platform. " "Example: skipif(sys.platform == 'win32') skips the test if we are on the win32 platform. "
"See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-skipif", 'See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-skipif',
) )
config.addinivalue_line( config.addinivalue_line(
"markers", 'markers',
"xfail(condition, ..., *, reason=..., run=True, raises=None, strict=xfail_strict): " 'xfail(condition, ..., *, reason=..., run=True, raises=None, strict=xfail_strict): '
"mark the test function as an expected failure if any of the conditions " 'mark the test function as an expected failure if any of the conditions '
"evaluate to True. Optionally specify a reason for better reporting " 'evaluate to True. Optionally specify a reason for better reporting '
"and run=False if you don't even want to execute the test function. " "and run=False if you don't even want to execute the test function. "
"If only specific exception(s) are expected, you can list them in " 'If only specific exception(s) are expected, you can list them in '
"raises, and if the test fails in other ways, it will be reported as " 'raises, and if the test fails in other ways, it will be reported as '
"a true failure. See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-xfail", 'a true failure. See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-xfail',
) )
def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool, str]: def evaluate_condition(item: Item, mark: Mark, condition: object) -> tuple[bool, str]:
"""Evaluate a single skipif/xfail condition. """Evaluate a single skipif/xfail condition.
If an old-style string condition is given, it is eval()'d, otherwise the If an old-style string condition is given, it is eval()'d, otherwise the
@ -95,40 +97,40 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
# String condition. # String condition.
if isinstance(condition, str): if isinstance(condition, str):
globals_ = { globals_ = {
"os": os, 'os': os,
"sys": sys, 'sys': sys,
"platform": platform, 'platform': platform,
"config": item.config, 'config': item.config,
} }
for dictionary in reversed( for dictionary in reversed(
item.ihook.pytest_markeval_namespace(config=item.config) item.ihook.pytest_markeval_namespace(config=item.config),
): ):
if not isinstance(dictionary, Mapping): if not isinstance(dictionary, Mapping):
raise ValueError( raise ValueError(
f"pytest_markeval_namespace() needs to return a dict, got {dictionary!r}" f'pytest_markeval_namespace() needs to return a dict, got {dictionary!r}',
) )
globals_.update(dictionary) globals_.update(dictionary)
if hasattr(item, "obj"): if hasattr(item, 'obj'):
globals_.update(item.obj.__globals__) # type: ignore[attr-defined] globals_.update(item.obj.__globals__) # type: ignore[attr-defined]
try: try:
filename = f"<{mark.name} condition>" filename = f'<{mark.name} condition>'
condition_code = compile(condition, filename, "eval") condition_code = compile(condition, filename, 'eval')
result = eval(condition_code, globals_) result = eval(condition_code, globals_)
except SyntaxError as exc: except SyntaxError as exc:
msglines = [ msglines = [
"Error evaluating %r condition" % mark.name, 'Error evaluating %r condition' % mark.name,
" " + condition, ' ' + condition,
" " + " " * (exc.offset or 0) + "^", ' ' + ' ' * (exc.offset or 0) + '^',
"SyntaxError: invalid syntax", 'SyntaxError: invalid syntax',
] ]
fail("\n".join(msglines), pytrace=False) fail('\n'.join(msglines), pytrace=False)
except Exception as exc: except Exception as exc:
msglines = [ msglines = [
"Error evaluating %r condition" % mark.name, 'Error evaluating %r condition' % mark.name,
" " + condition, ' ' + condition,
*traceback.format_exception_only(type(exc), exc), *traceback.format_exception_only(type(exc), exc),
] ]
fail("\n".join(msglines), pytrace=False) fail('\n'.join(msglines), pytrace=False)
# Boolean condition. # Boolean condition.
else: else:
@ -136,20 +138,20 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
result = bool(condition) result = bool(condition)
except Exception as exc: except Exception as exc:
msglines = [ msglines = [
"Error evaluating %r condition as a boolean" % mark.name, 'Error evaluating %r condition as a boolean' % mark.name,
*traceback.format_exception_only(type(exc), exc), *traceback.format_exception_only(type(exc), exc),
] ]
fail("\n".join(msglines), pytrace=False) fail('\n'.join(msglines), pytrace=False)
reason = mark.kwargs.get("reason", None) reason = mark.kwargs.get('reason', None)
if reason is None: if reason is None:
if isinstance(condition, str): if isinstance(condition, str):
reason = "condition: " + condition reason = 'condition: ' + condition
else: else:
# XXX better be checked at collection time # XXX better be checked at collection time
msg = ( msg = (
"Error evaluating %r: " % mark.name 'Error evaluating %r: ' % mark.name +
+ "you need to specify reason=STRING when using booleans as conditions." 'you need to specify reason=STRING when using booleans as conditions.'
) )
fail(msg, pytrace=False) fail(msg, pytrace=False)
@ -160,20 +162,20 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
class Skip: class Skip:
"""The result of evaluate_skip_marks().""" """The result of evaluate_skip_marks()."""
reason: str = "unconditional skip" reason: str = 'unconditional skip'
def evaluate_skip_marks(item: Item) -> Optional[Skip]: def evaluate_skip_marks(item: Item) -> Skip | None:
"""Evaluate skip and skipif marks on item, returning Skip if triggered.""" """Evaluate skip and skipif marks on item, returning Skip if triggered."""
for mark in item.iter_markers(name="skipif"): for mark in item.iter_markers(name='skipif'):
if "condition" not in mark.kwargs: if 'condition' not in mark.kwargs:
conditions = mark.args conditions = mark.args
else: else:
conditions = (mark.kwargs["condition"],) conditions = (mark.kwargs['condition'],)
# Unconditional. # Unconditional.
if not conditions: if not conditions:
reason = mark.kwargs.get("reason", "") reason = mark.kwargs.get('reason', '')
return Skip(reason) return Skip(reason)
# If any of the conditions are true. # If any of the conditions are true.
@ -182,11 +184,11 @@ def evaluate_skip_marks(item: Item) -> Optional[Skip]:
if result: if result:
return Skip(reason) return Skip(reason)
for mark in item.iter_markers(name="skip"): for mark in item.iter_markers(name='skip'):
try: try:
return Skip(*mark.args, **mark.kwargs) return Skip(*mark.args, **mark.kwargs)
except TypeError as e: except TypeError as e:
raise TypeError(str(e) + " - maybe you meant pytest.mark.skipif?") from None raise TypeError(str(e) + ' - maybe you meant pytest.mark.skipif?') from None
return None return None
@ -195,28 +197,28 @@ def evaluate_skip_marks(item: Item) -> Optional[Skip]:
class Xfail: class Xfail:
"""The result of evaluate_xfail_marks().""" """The result of evaluate_xfail_marks()."""
__slots__ = ("reason", "run", "strict", "raises") __slots__ = ('reason', 'run', 'strict', 'raises')
reason: str reason: str
run: bool run: bool
strict: bool strict: bool
raises: Optional[Tuple[Type[BaseException], ...]] raises: tuple[type[BaseException], ...] | None
def evaluate_xfail_marks(item: Item) -> Optional[Xfail]: def evaluate_xfail_marks(item: Item) -> Xfail | None:
"""Evaluate xfail marks on item, returning Xfail if triggered.""" """Evaluate xfail marks on item, returning Xfail if triggered."""
for mark in item.iter_markers(name="xfail"): for mark in item.iter_markers(name='xfail'):
run = mark.kwargs.get("run", True) run = mark.kwargs.get('run', True)
strict = mark.kwargs.get("strict", item.config.getini("xfail_strict")) strict = mark.kwargs.get('strict', item.config.getini('xfail_strict'))
raises = mark.kwargs.get("raises", None) raises = mark.kwargs.get('raises', None)
if "condition" not in mark.kwargs: if 'condition' not in mark.kwargs:
conditions = mark.args conditions = mark.args
else: else:
conditions = (mark.kwargs["condition"],) conditions = (mark.kwargs['condition'],)
# Unconditional. # Unconditional.
if not conditions: if not conditions:
reason = mark.kwargs.get("reason", "") reason = mark.kwargs.get('reason', '')
return Xfail(reason, run, strict, raises) return Xfail(reason, run, strict, raises)
# If any of the conditions are true. # If any of the conditions are true.
@ -240,7 +242,7 @@ def pytest_runtest_setup(item: Item) -> None:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item) item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run: if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason) xfail('[NOTRUN] ' + xfailed.reason)
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
@ -250,7 +252,7 @@ def pytest_runtest_call(item: Item) -> Generator[None, None, None]:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item) item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run: if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason) xfail('[NOTRUN] ' + xfailed.reason)
try: try:
return (yield) return (yield)
@ -263,7 +265,7 @@ def pytest_runtest_call(item: Item) -> Generator[None, None, None]:
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_makereport( def pytest_runtest_makereport(
item: Item, call: CallInfo[None] item: Item, call: CallInfo[None],
) -> Generator[None, TestReport, TestReport]: ) -> Generator[None, TestReport, TestReport]:
rep = yield rep = yield
xfailed = item.stash.get(xfailed_key, None) xfailed = item.stash.get(xfailed_key, None)
@ -271,30 +273,30 @@ def pytest_runtest_makereport(
pass # don't interfere pass # don't interfere
elif call.excinfo and isinstance(call.excinfo.value, xfail.Exception): elif call.excinfo and isinstance(call.excinfo.value, xfail.Exception):
assert call.excinfo.value.msg is not None assert call.excinfo.value.msg is not None
rep.wasxfail = "reason: " + call.excinfo.value.msg rep.wasxfail = 'reason: ' + call.excinfo.value.msg
rep.outcome = "skipped" rep.outcome = 'skipped'
elif not rep.skipped and xfailed: elif not rep.skipped and xfailed:
if call.excinfo: if call.excinfo:
raises = xfailed.raises raises = xfailed.raises
if raises is not None and not isinstance(call.excinfo.value, raises): if raises is not None and not isinstance(call.excinfo.value, raises):
rep.outcome = "failed" rep.outcome = 'failed'
else: else:
rep.outcome = "skipped" rep.outcome = 'skipped'
rep.wasxfail = xfailed.reason rep.wasxfail = xfailed.reason
elif call.when == "call": elif call.when == 'call':
if xfailed.strict: if xfailed.strict:
rep.outcome = "failed" rep.outcome = 'failed'
rep.longrepr = "[XPASS(strict)] " + xfailed.reason rep.longrepr = '[XPASS(strict)] ' + xfailed.reason
else: else:
rep.outcome = "passed" rep.outcome = 'passed'
rep.wasxfail = xfailed.reason rep.wasxfail = xfailed.reason
return rep return rep
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]: def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if hasattr(report, "wasxfail"): if hasattr(report, 'wasxfail'):
if report.skipped: if report.skipped:
return "xfailed", "x", "XFAIL" return 'xfailed', 'x', 'XFAIL'
elif report.passed: elif report.passed:
return "xpassed", "X", "XPASS" return 'xpassed', 'X', 'XPASS'
return None return None

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any from typing import Any
from typing import cast from typing import cast
from typing import Dict from typing import Dict
@ -6,11 +8,11 @@ from typing import TypeVar
from typing import Union from typing import Union
__all__ = ["Stash", "StashKey"] __all__ = ['Stash', 'StashKey']
T = TypeVar("T") T = TypeVar('T')
D = TypeVar("D") D = TypeVar('D')
class StashKey(Generic[T]): class StashKey(Generic[T]):
@ -63,10 +65,10 @@ class Stash:
some_bool = stash[some_bool_key] some_bool = stash[some_bool_key]
""" """
__slots__ = ("_storage",) __slots__ = ('_storage',)
def __init__(self) -> None: def __init__(self) -> None:
self._storage: Dict[StashKey[Any], object] = {} self._storage: dict[StashKey[Any], object] = {}
def __setitem__(self, key: StashKey[T], value: T) -> None: def __setitem__(self, key: StashKey[T], value: T) -> None:
"""Set a value for key.""" """Set a value for key."""
@ -79,7 +81,7 @@ class Stash:
""" """
return cast(T, self._storage[key]) return cast(T, self._storage[key])
def get(self, key: StashKey[T], default: D) -> Union[T, D]: def get(self, key: StashKey[T], default: D) -> T | D:
"""Get the value for key, or return default if the key wasn't set """Get the value for key, or return default if the key wasn't set
before.""" before."""
try: try:

View file

@ -1,39 +1,41 @@
from __future__ import annotations
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import pytest
from _pytest import nodes from _pytest import nodes
from _pytest.config import Config from _pytest.config import Config
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.main import Session from _pytest.main import Session
from _pytest.reports import TestReport from _pytest.reports import TestReport
import pytest
if TYPE_CHECKING: if TYPE_CHECKING:
from _pytest.cacheprovider import Cache from _pytest.cacheprovider import Cache
STEPWISE_CACHE_DIR = "cache/stepwise" STEPWISE_CACHE_DIR = 'cache/stepwise'
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general") group = parser.getgroup('general')
group.addoption( group.addoption(
"--sw", '--sw',
"--stepwise", '--stepwise',
action="store_true", action='store_true',
default=False, default=False,
dest="stepwise", dest='stepwise',
help="Exit on test failure and continue from last failing test next time", help='Exit on test failure and continue from last failing test next time',
) )
group.addoption( group.addoption(
"--sw-skip", '--sw-skip',
"--stepwise-skip", '--stepwise-skip',
action="store_true", action='store_true',
default=False, default=False,
dest="stepwise_skip", dest='stepwise_skip',
help="Ignore the first failing test but stop on the next failing test. " help='Ignore the first failing test but stop on the next failing test. '
"Implicitly enables --stepwise.", 'Implicitly enables --stepwise.',
) )
@ -42,14 +44,14 @@ def pytest_configure(config: Config) -> None:
if config.option.stepwise_skip: if config.option.stepwise_skip:
# allow --stepwise-skip to work on it's own merits. # allow --stepwise-skip to work on it's own merits.
config.option.stepwise = True config.option.stepwise = True
if config.getoption("stepwise"): if config.getoption('stepwise'):
config.pluginmanager.register(StepwisePlugin(config), "stepwiseplugin") config.pluginmanager.register(StepwisePlugin(config), 'stepwiseplugin')
def pytest_sessionfinish(session: Session) -> None: def pytest_sessionfinish(session: Session) -> None:
if not session.config.getoption("stepwise"): if not session.config.getoption('stepwise'):
assert session.config.cache is not None assert session.config.cache is not None
if hasattr(session.config, "workerinput"): if hasattr(session.config, 'workerinput'):
# Do not update cache if this process is a xdist worker to prevent # Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641). # race conditions (#10641).
return return
@ -60,21 +62,21 @@ def pytest_sessionfinish(session: Session) -> None:
class StepwisePlugin: class StepwisePlugin:
def __init__(self, config: Config) -> None: def __init__(self, config: Config) -> None:
self.config = config self.config = config
self.session: Optional[Session] = None self.session: Session | None = None
self.report_status = "" self.report_status = ''
assert config.cache is not None assert config.cache is not None
self.cache: Cache = config.cache self.cache: Cache = config.cache
self.lastfailed: Optional[str] = self.cache.get(STEPWISE_CACHE_DIR, None) self.lastfailed: str | None = self.cache.get(STEPWISE_CACHE_DIR, None)
self.skip: bool = config.getoption("stepwise_skip") self.skip: bool = config.getoption('stepwise_skip')
def pytest_sessionstart(self, session: Session) -> None: def pytest_sessionstart(self, session: Session) -> None:
self.session = session self.session = session
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item] self, config: Config, items: list[nodes.Item],
) -> None: ) -> None:
if not self.lastfailed: if not self.lastfailed:
self.report_status = "no previously failed tests, not skipping." self.report_status = 'no previously failed tests, not skipping.'
return return
# check all item nodes until we find a match on last failed # check all item nodes until we find a match on last failed
@ -87,9 +89,9 @@ class StepwisePlugin:
# If the previously failed test was not found among the test items, # If the previously failed test was not found among the test items,
# do not skip any tests. # do not skip any tests.
if failed_index is None: if failed_index is None:
self.report_status = "previously failed test not found, not skipping." self.report_status = 'previously failed test not found, not skipping.'
else: else:
self.report_status = f"skipping {failed_index} already passed items." self.report_status = f'skipping {failed_index} already passed items.'
deselected = items[:failed_index] deselected = items[:failed_index]
del items[:failed_index] del items[:failed_index]
config.hook.pytest_deselected(items=deselected) config.hook.pytest_deselected(items=deselected)
@ -108,23 +110,23 @@ class StepwisePlugin:
self.lastfailed = report.nodeid self.lastfailed = report.nodeid
assert self.session is not None assert self.session is not None
self.session.shouldstop = ( self.session.shouldstop = (
"Test failed, continuing from this test next run." 'Test failed, continuing from this test next run.'
) )
else: else:
# If the test was actually run and did pass. # If the test was actually run and did pass.
if report.when == "call": if report.when == 'call':
# Remove test from the failed ones, if exists. # Remove test from the failed ones, if exists.
if report.nodeid == self.lastfailed: if report.nodeid == self.lastfailed:
self.lastfailed = None self.lastfailed = None
def pytest_report_collectionfinish(self) -> Optional[str]: def pytest_report_collectionfinish(self) -> str | None:
if self.config.getoption("verbose") >= 0 and self.report_status: if self.config.getoption('verbose') >= 0 and self.report_status:
return f"stepwise: {self.report_status}" return f'stepwise: {self.report_status}'
return None return None
def pytest_sessionfinish(self) -> None: def pytest_sessionfinish(self) -> None:
if hasattr(self.config, "workerinput"): if hasattr(self.config, 'workerinput'):
# Do not update cache if this process is a xdist worker to prevent # Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641). # race conditions (#10641).
return return

File diff suppressed because it is too large Load diff

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import threading import threading
import traceback import traceback
import warnings
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Generator from typing import Generator
from typing import Optional from typing import Optional
from typing import Type from typing import Type
import warnings
import pytest import pytest
@ -34,22 +36,22 @@ class catch_threading_exception:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.args: Optional["threading.ExceptHookArgs"] = None self.args: threading.ExceptHookArgs | None = None
self._old_hook: Optional[Callable[["threading.ExceptHookArgs"], Any]] = None self._old_hook: Callable[[threading.ExceptHookArgs], Any] | None = None
def _hook(self, args: "threading.ExceptHookArgs") -> None: def _hook(self, args: threading.ExceptHookArgs) -> None:
self.args = args self.args = args
def __enter__(self) -> "catch_threading_exception": def __enter__(self) -> catch_threading_exception:
self._old_hook = threading.excepthook self._old_hook = threading.excepthook
threading.excepthook = self._hook threading.excepthook = self._hook
return self return self
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
assert self._old_hook is not None assert self._old_hook is not None
threading.excepthook = self._old_hook threading.excepthook = self._old_hook
@ -64,15 +66,15 @@ def thread_exception_runtest_hook() -> Generator[None, None, None]:
finally: finally:
if cm.args: if cm.args:
thread_name = ( thread_name = (
"<unknown>" if cm.args.thread is None else cm.args.thread.name '<unknown>' if cm.args.thread is None else cm.args.thread.name
) )
msg = f"Exception in thread {thread_name}\n\n" msg = f'Exception in thread {thread_name}\n\n'
msg += "".join( msg += ''.join(
traceback.format_exception( traceback.format_exception(
cm.args.exc_type, cm.args.exc_type,
cm.args.exc_value, cm.args.exc_value,
cm.args.exc_traceback, cm.args.exc_traceback,
) ),
) )
warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg)) warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))

View file

@ -5,10 +5,11 @@ pytest runtime information (issue #185).
Fixture "mock_timing" also interacts with this module for pytest's own tests. Fixture "mock_timing" also interacts with this module for pytest's own tests.
""" """
from __future__ import annotations
from time import perf_counter from time import perf_counter
from time import sleep from time import sleep
from time import time from time import time
__all__ = ["perf_counter", "sleep", "time"] __all__ = ['perf_counter', 'sleep', 'time']

View file

@ -1,11 +1,13 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Support for providing temporary directories to test functions.""" """Support for providing temporary directories to test functions."""
from __future__ import annotations
import dataclasses import dataclasses
import os import os
from pathlib import Path
import re import re
from shutil import rmtree
import tempfile import tempfile
from pathlib import Path
from shutil import rmtree
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import final from typing import final
@ -14,11 +16,6 @@ from typing import Literal
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from .pathlib import cleanup_dead_symlinks
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
from _pytest.compat import get_user_id from _pytest.compat import get_user_id
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
@ -32,9 +29,15 @@ from _pytest.nodes import Item
from _pytest.reports import TestReport from _pytest.reports import TestReport
from _pytest.stash import StashKey from _pytest.stash import StashKey
from .pathlib import cleanup_dead_symlinks
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
tmppath_result_key = StashKey[Dict[str, bool]]() tmppath_result_key = StashKey[Dict[str, bool]]()
RetentionType = Literal["all", "failed", "none"] RetentionType = Literal['all', 'failed', 'none']
@final @final
@ -45,20 +48,20 @@ class TempPathFactory:
The base directory can be configured using the ``--basetemp`` option. The base directory can be configured using the ``--basetemp`` option.
""" """
_given_basetemp: Optional[Path] _given_basetemp: Path | None
# pluggy TagTracerSub, not currently exposed, so Any. # pluggy TagTracerSub, not currently exposed, so Any.
_trace: Any _trace: Any
_basetemp: Optional[Path] _basetemp: Path | None
_retention_count: int _retention_count: int
_retention_policy: RetentionType _retention_policy: RetentionType
def __init__( def __init__(
self, self,
given_basetemp: Optional[Path], given_basetemp: Path | None,
retention_count: int, retention_count: int,
retention_policy: RetentionType, retention_policy: RetentionType,
trace, trace,
basetemp: Optional[Path] = None, basetemp: Path | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -81,27 +84,27 @@ class TempPathFactory:
config: Config, config: Config,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> "TempPathFactory": ) -> TempPathFactory:
"""Create a factory according to pytest configuration. """Create a factory according to pytest configuration.
:meta private: :meta private:
""" """
check_ispytest(_ispytest) check_ispytest(_ispytest)
count = int(config.getini("tmp_path_retention_count")) count = int(config.getini('tmp_path_retention_count'))
if count < 0: if count < 0:
raise ValueError( raise ValueError(
f"tmp_path_retention_count must be >= 0. Current input: {count}." f'tmp_path_retention_count must be >= 0. Current input: {count}.',
) )
policy = config.getini("tmp_path_retention_policy") policy = config.getini('tmp_path_retention_policy')
if policy not in ("all", "failed", "none"): if policy not in ('all', 'failed', 'none'):
raise ValueError( raise ValueError(
f"tmp_path_retention_policy must be either all, failed, none. Current input: {policy}." f'tmp_path_retention_policy must be either all, failed, none. Current input: {policy}.',
) )
return cls( return cls(
given_basetemp=config.option.basetemp, given_basetemp=config.option.basetemp,
trace=config.trace.get("tmpdir"), trace=config.trace.get('tmpdir'),
retention_count=count, retention_count=count,
retention_policy=policy, retention_policy=policy,
_ispytest=True, _ispytest=True,
@ -110,7 +113,7 @@ class TempPathFactory:
def _ensure_relative_to_basetemp(self, basename: str) -> str: def _ensure_relative_to_basetemp(self, basename: str) -> str:
basename = os.path.normpath(basename) basename = os.path.normpath(basename)
if (self.getbasetemp() / basename).resolve().parent != self.getbasetemp(): if (self.getbasetemp() / basename).resolve().parent != self.getbasetemp():
raise ValueError(f"{basename} is not a normalized and relative path") raise ValueError(f'{basename} is not a normalized and relative path')
return basename return basename
def mktemp(self, basename: str, numbered: bool = True) -> Path: def mktemp(self, basename: str, numbered: bool = True) -> Path:
@ -134,7 +137,7 @@ class TempPathFactory:
p.mkdir(mode=0o700) p.mkdir(mode=0o700)
else: else:
p = make_numbered_dir(root=self.getbasetemp(), prefix=basename, mode=0o700) p = make_numbered_dir(root=self.getbasetemp(), prefix=basename, mode=0o700)
self._trace("mktemp", p) self._trace('mktemp', p)
return p return p
def getbasetemp(self) -> Path: def getbasetemp(self) -> Path:
@ -153,17 +156,17 @@ class TempPathFactory:
basetemp.mkdir(mode=0o700) basetemp.mkdir(mode=0o700)
basetemp = basetemp.resolve() basetemp = basetemp.resolve()
else: else:
from_env = os.environ.get("PYTEST_DEBUG_TEMPROOT") from_env = os.environ.get('PYTEST_DEBUG_TEMPROOT')
temproot = Path(from_env or tempfile.gettempdir()).resolve() temproot = Path(from_env or tempfile.gettempdir()).resolve()
user = get_user() or "unknown" user = get_user() or 'unknown'
# use a sub-directory in the temproot to speed-up # use a sub-directory in the temproot to speed-up
# make_numbered_dir() call # make_numbered_dir() call
rootdir = temproot.joinpath(f"pytest-of-{user}") rootdir = temproot.joinpath(f'pytest-of-{user}')
try: try:
rootdir.mkdir(mode=0o700, exist_ok=True) rootdir.mkdir(mode=0o700, exist_ok=True)
except OSError: except OSError:
# getuser() likely returned illegal characters for the platform, use unknown back off mechanism # getuser() likely returned illegal characters for the platform, use unknown back off mechanism
rootdir = temproot.joinpath("pytest-of-unknown") rootdir = temproot.joinpath('pytest-of-unknown')
rootdir.mkdir(mode=0o700, exist_ok=True) rootdir.mkdir(mode=0o700, exist_ok=True)
# Because we use exist_ok=True with a predictable name, make sure # Because we use exist_ok=True with a predictable name, make sure
# we are the owners, to prevent any funny business (on unix, where # we are the owners, to prevent any funny business (on unix, where
@ -176,16 +179,16 @@ class TempPathFactory:
rootdir_stat = rootdir.stat() rootdir_stat = rootdir.stat()
if rootdir_stat.st_uid != uid: if rootdir_stat.st_uid != uid:
raise OSError( raise OSError(
f"The temporary directory {rootdir} is not owned by the current user. " f'The temporary directory {rootdir} is not owned by the current user. '
"Fix this and try again." 'Fix this and try again.',
) )
if (rootdir_stat.st_mode & 0o077) != 0: if (rootdir_stat.st_mode & 0o077) != 0:
os.chmod(rootdir, rootdir_stat.st_mode & ~0o077) os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
keep = self._retention_count keep = self._retention_count
if self._retention_policy == "none": if self._retention_policy == 'none':
keep = 0 keep = 0
basetemp = make_numbered_dir_with_cleanup( basetemp = make_numbered_dir_with_cleanup(
prefix="pytest-", prefix='pytest-',
root=rootdir, root=rootdir,
keep=keep, keep=keep,
lock_timeout=LOCK_TIMEOUT, lock_timeout=LOCK_TIMEOUT,
@ -193,11 +196,11 @@ class TempPathFactory:
) )
assert basetemp is not None, basetemp assert basetemp is not None, basetemp
self._basetemp = basetemp self._basetemp = basetemp
self._trace("new basetemp", basetemp) self._trace('new basetemp', basetemp)
return basetemp return basetemp
def get_user() -> Optional[str]: def get_user() -> str | None:
"""Return the current user name, or None if getuser() does not work """Return the current user name, or None if getuser() does not work
in the current environment (see #1010).""" in the current environment (see #1010)."""
try: try:
@ -219,25 +222,25 @@ def pytest_configure(config: Config) -> None:
mp = MonkeyPatch() mp = MonkeyPatch()
config.add_cleanup(mp.undo) config.add_cleanup(mp.undo)
_tmp_path_factory = TempPathFactory.from_config(config, _ispytest=True) _tmp_path_factory = TempPathFactory.from_config(config, _ispytest=True)
mp.setattr(config, "_tmp_path_factory", _tmp_path_factory, raising=False) mp.setattr(config, '_tmp_path_factory', _tmp_path_factory, raising=False)
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
parser.addini( parser.addini(
"tmp_path_retention_count", 'tmp_path_retention_count',
help="How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`.", help='How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`.',
default=3, default=3,
) )
parser.addini( parser.addini(
"tmp_path_retention_policy", 'tmp_path_retention_policy',
help="Controls which directories created by the `tmp_path` fixture are kept around, based on test outcome. " help='Controls which directories created by the `tmp_path` fixture are kept around, based on test outcome. '
"(all/failed/none)", '(all/failed/none)',
default="all", default='all',
) )
@fixture(scope="session") @fixture(scope='session')
def tmp_path_factory(request: FixtureRequest) -> TempPathFactory: def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
"""Return a :class:`pytest.TempPathFactory` instance for the test session.""" """Return a :class:`pytest.TempPathFactory` instance for the test session."""
# Set dynamically by pytest_configure() above. # Set dynamically by pytest_configure() above.
@ -246,7 +249,7 @@ def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path: def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
name = request.node.name name = request.node.name
name = re.sub(r"[\W]", "_", name) name = re.sub(r'[\W]', '_', name)
MAXVAL = 30 MAXVAL = 30
name = name[:MAXVAL] name = name[:MAXVAL]
return factory.mktemp(name, numbered=True) return factory.mktemp(name, numbered=True)
@ -254,7 +257,7 @@ def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
@fixture @fixture
def tmp_path( def tmp_path(
request: FixtureRequest, tmp_path_factory: TempPathFactory request: FixtureRequest, tmp_path_factory: TempPathFactory,
) -> Generator[Path, None, None]: ) -> Generator[Path, None, None]:
"""Return a temporary directory path object which is unique to each test """Return a temporary directory path object which is unique to each test
function invocation, created as a sub directory of the base temporary function invocation, created as a sub directory of the base temporary
@ -277,7 +280,7 @@ def tmp_path(
policy = tmp_path_factory._retention_policy policy = tmp_path_factory._retention_policy
result_dict = request.node.stash[tmppath_result_key] result_dict = request.node.stash[tmppath_result_key]
if policy == "failed" and result_dict.get("call", True): if policy == 'failed' and result_dict.get('call', True):
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource, # We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
# permissions, etc, in which case we ignore it. # permissions, etc, in which case we ignore it.
rmtree(path, ignore_errors=True) rmtree(path, ignore_errors=True)
@ -285,7 +288,7 @@ def tmp_path(
del request.node.stash[tmppath_result_key] del request.node.stash[tmppath_result_key]
def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]): def pytest_sessionfinish(session, exitstatus: int | ExitCode):
"""After each session, remove base directory if all the tests passed, """After each session, remove base directory if all the tests passed,
the policy is "failed", and the basetemp is not specified by a user. the policy is "failed", and the basetemp is not specified by a user.
""" """
@ -296,9 +299,9 @@ def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]):
policy = tmp_path_factory._retention_policy policy = tmp_path_factory._retention_policy
if ( if (
exitstatus == 0 exitstatus == 0 and
and policy == "failed" policy == 'failed' and
and tmp_path_factory._given_basetemp is None tmp_path_factory._given_basetemp is None
): ):
if basetemp.is_dir(): if basetemp.is_dir():
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource, # We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
@ -312,10 +315,10 @@ def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]):
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_makereport( def pytest_runtest_makereport(
item: Item, call item: Item, call,
) -> Generator[None, TestReport, TestReport]: ) -> Generator[None, TestReport, TestReport]:
rep = yield rep = yield
assert rep.when is not None assert rep.when is not None
empty: Dict[str, bool] = {} empty: dict[str, bool] = {}
item.stash.setdefault(tmppath_result_key, empty)[rep.when] = rep.passed item.stash.setdefault(tmppath_result_key, empty)[rep.when] = rep.passed
return rep return rep

View file

@ -1,5 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Discover and run std-library "unittest" style tests.""" """Discover and run std-library "unittest" style tests."""
from __future__ import annotations
import sys import sys
import traceback import traceback
import types import types
@ -15,6 +17,7 @@ from typing import TYPE_CHECKING
from typing import Union from typing import Union
import _pytest._code import _pytest._code
import pytest
from _pytest.compat import getimfunc from _pytest.compat import getimfunc
from _pytest.compat import is_async_function from _pytest.compat import is_async_function
from _pytest.config import hookimpl from _pytest.config import hookimpl
@ -29,7 +32,6 @@ from _pytest.python import Class
from _pytest.python import Function from _pytest.python import Function
from _pytest.python import Module from _pytest.python import Module
from _pytest.runner import CallInfo from _pytest.runner import CallInfo
import pytest
if TYPE_CHECKING: if TYPE_CHECKING:
@ -44,11 +46,11 @@ if TYPE_CHECKING:
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: Union[Module, Class], name: str, obj: object collector: Module | Class, name: str, obj: object,
) -> Optional["UnitTestCase"]: ) -> UnitTestCase | None:
# Has unittest been imported and is obj a subclass of its TestCase? # Has unittest been imported and is obj a subclass of its TestCase?
try: try:
ut = sys.modules["unittest"] ut = sys.modules['unittest']
# Type ignored because `ut` is an opaque module. # Type ignored because `ut` is an opaque module.
if not issubclass(obj, ut.TestCase): # type: ignore if not issubclass(obj, ut.TestCase): # type: ignore
return None return None
@ -63,11 +65,11 @@ class UnitTestCase(Class):
# to declare that our children do not support funcargs. # to declare that our children do not support funcargs.
nofuncargs = True nofuncargs = True
def collect(self) -> Iterable[Union[Item, Collector]]: def collect(self) -> Iterable[Item | Collector]:
from unittest import TestLoader from unittest import TestLoader
cls = self.obj cls = self.obj
if not getattr(cls, "__test__", True): if not getattr(cls, '__test__', True):
return return
skipped = _is_skipped(cls) skipped = _is_skipped(cls)
@ -81,28 +83,28 @@ class UnitTestCase(Class):
foundsomething = False foundsomething = False
for name in loader.getTestCaseNames(self.obj): for name in loader.getTestCaseNames(self.obj):
x = getattr(self.obj, name) x = getattr(self.obj, name)
if not getattr(x, "__test__", True): if not getattr(x, '__test__', True):
continue continue
funcobj = getimfunc(x) funcobj = getimfunc(x)
yield TestCaseFunction.from_parent(self, name=name, callobj=funcobj) yield TestCaseFunction.from_parent(self, name=name, callobj=funcobj)
foundsomething = True foundsomething = True
if not foundsomething: if not foundsomething:
runtest = getattr(self.obj, "runTest", None) runtest = getattr(self.obj, 'runTest', None)
if runtest is not None: if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None) ut = sys.modules.get('twisted.trial.unittest', None)
# Type ignored because `ut` is an opaque module. # Type ignored because `ut` is an opaque module.
if ut is None or runtest != ut.TestCase.runTest: # type: ignore if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest") yield TestCaseFunction.from_parent(self, name='runTest')
def _register_unittest_setup_class_fixture(self, cls: type) -> None: def _register_unittest_setup_class_fixture(self, cls: type) -> None:
"""Register an auto-use fixture to invoke setUpClass and """Register an auto-use fixture to invoke setUpClass and
tearDownClass (#517).""" tearDownClass (#517)."""
setup = getattr(cls, "setUpClass", None) setup = getattr(cls, 'setUpClass', None)
teardown = getattr(cls, "tearDownClass", None) teardown = getattr(cls, 'tearDownClass', None)
if setup is None and teardown is None: if setup is None and teardown is None:
return None return None
cleanup = getattr(cls, "doClassCleanups", lambda: None) cleanup = getattr(cls, 'doClassCleanups', lambda: None)
def unittest_setup_class_fixture( def unittest_setup_class_fixture(
request: FixtureRequest, request: FixtureRequest,
@ -128,18 +130,18 @@ class UnitTestCase(Class):
self.session._fixturemanager._register_fixture( self.session._fixturemanager._register_fixture(
# Use a unique name to speed up lookup. # Use a unique name to speed up lookup.
name=f"_unittest_setUpClass_fixture_{cls.__qualname__}", name=f'_unittest_setUpClass_fixture_{cls.__qualname__}',
func=unittest_setup_class_fixture, func=unittest_setup_class_fixture,
nodeid=self.nodeid, nodeid=self.nodeid,
scope="class", scope='class',
autouse=True, autouse=True,
) )
def _register_unittest_setup_method_fixture(self, cls: type) -> None: def _register_unittest_setup_method_fixture(self, cls: type) -> None:
"""Register an auto-use fixture to invoke setup_method and """Register an auto-use fixture to invoke setup_method and
teardown_method (#517).""" teardown_method (#517)."""
setup = getattr(cls, "setup_method", None) setup = getattr(cls, 'setup_method', None)
teardown = getattr(cls, "teardown_method", None) teardown = getattr(cls, 'teardown_method', None)
if setup is None and teardown is None: if setup is None and teardown is None:
return None return None
@ -158,18 +160,18 @@ class UnitTestCase(Class):
self.session._fixturemanager._register_fixture( self.session._fixturemanager._register_fixture(
# Use a unique name to speed up lookup. # Use a unique name to speed up lookup.
name=f"_unittest_setup_method_fixture_{cls.__qualname__}", name=f'_unittest_setup_method_fixture_{cls.__qualname__}',
func=unittest_setup_method_fixture, func=unittest_setup_method_fixture,
nodeid=self.nodeid, nodeid=self.nodeid,
scope="function", scope='function',
autouse=True, autouse=True,
) )
class TestCaseFunction(Function): class TestCaseFunction(Function):
nofuncargs = True nofuncargs = True
_excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None _excinfo: list[_pytest._code.ExceptionInfo[BaseException]] | None = None
_testcase: Optional["unittest.TestCase"] = None _testcase: unittest.TestCase | None = None
def _getobj(self): def _getobj(self):
assert self.parent is not None assert self.parent is not None
@ -182,7 +184,7 @@ class TestCaseFunction(Function):
def setup(self) -> None: def setup(self) -> None:
# A bound method to be called during teardown() if set (see 'runtest()'). # A bound method to be called during teardown() if set (see 'runtest()').
self._explicit_tearDown: Optional[Callable[[], None]] = None self._explicit_tearDown: Callable[[], None] | None = None
assert self.parent is not None assert self.parent is not None
self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined] self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined]
self._obj = getattr(self._testcase, self.name) self._obj = getattr(self._testcase, self.name)
@ -196,15 +198,15 @@ class TestCaseFunction(Function):
self._testcase = None self._testcase = None
self._obj = None self._obj = None
def startTest(self, testcase: "unittest.TestCase") -> None: def startTest(self, testcase: unittest.TestCase) -> None:
pass pass
def _addexcinfo(self, rawexcinfo: "_SysExcInfoType") -> None: def _addexcinfo(self, rawexcinfo: _SysExcInfoType) -> None:
# Unwrap potential exception info (see twisted trial support below). # Unwrap potential exception info (see twisted trial support below).
rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo) rawexcinfo = getattr(rawexcinfo, '_rawexcinfo', rawexcinfo)
try: try:
excinfo = _pytest._code.ExceptionInfo[BaseException].from_exc_info( excinfo = _pytest._code.ExceptionInfo[BaseException].from_exc_info(
rawexcinfo # type: ignore[arg-type] rawexcinfo, # type: ignore[arg-type]
) )
# Invoke the attributes to trigger storing the traceback # Invoke the attributes to trigger storing the traceback
# trial causes some issue there. # trial causes some issue there.
@ -216,26 +218,26 @@ class TestCaseFunction(Function):
values = traceback.format_exception(*rawexcinfo) values = traceback.format_exception(*rawexcinfo)
values.insert( values.insert(
0, 0,
"NOTE: Incompatible Exception Representation, " 'NOTE: Incompatible Exception Representation, '
"displaying natively:\n\n", 'displaying natively:\n\n',
) )
fail("".join(values), pytrace=False) fail(''.join(values), pytrace=False)
except (fail.Exception, KeyboardInterrupt): except (fail.Exception, KeyboardInterrupt):
raise raise
except BaseException: except BaseException:
fail( fail(
"ERROR: Unknown Incompatible Exception " 'ERROR: Unknown Incompatible Exception '
f"representation:\n{rawexcinfo!r}", f'representation:\n{rawexcinfo!r}',
pytrace=False, pytrace=False,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
except fail.Exception: except fail.Exception:
excinfo = _pytest._code.ExceptionInfo.from_current() excinfo = _pytest._code.ExceptionInfo.from_current()
self.__dict__.setdefault("_excinfo", []).append(excinfo) self.__dict__.setdefault('_excinfo', []).append(excinfo)
def addError( def addError(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType" self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType,
) -> None: ) -> None:
try: try:
if isinstance(rawexcinfo[1], exit.Exception): if isinstance(rawexcinfo[1], exit.Exception):
@ -245,11 +247,11 @@ class TestCaseFunction(Function):
self._addexcinfo(rawexcinfo) self._addexcinfo(rawexcinfo)
def addFailure( def addFailure(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType" self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType,
) -> None: ) -> None:
self._addexcinfo(rawexcinfo) self._addexcinfo(rawexcinfo)
def addSkip(self, testcase: "unittest.TestCase", reason: str) -> None: def addSkip(self, testcase: unittest.TestCase, reason: str) -> None:
try: try:
raise pytest.skip.Exception(reason, _use_item_location=True) raise pytest.skip.Exception(reason, _use_item_location=True)
except skip.Exception: except skip.Exception:
@ -257,9 +259,9 @@ class TestCaseFunction(Function):
def addExpectedFailure( def addExpectedFailure(
self, self,
testcase: "unittest.TestCase", testcase: unittest.TestCase,
rawexcinfo: "_SysExcInfoType", rawexcinfo: _SysExcInfoType,
reason: str = "", reason: str = '',
) -> None: ) -> None:
try: try:
xfail(str(reason)) xfail(str(reason))
@ -268,25 +270,25 @@ class TestCaseFunction(Function):
def addUnexpectedSuccess( def addUnexpectedSuccess(
self, self,
testcase: "unittest.TestCase", testcase: unittest.TestCase,
reason: Optional["twisted.trial.unittest.Todo"] = None, reason: twisted.trial.unittest.Todo | None = None,
) -> None: ) -> None:
msg = "Unexpected success" msg = 'Unexpected success'
if reason: if reason:
msg += f": {reason.reason}" msg += f': {reason.reason}'
# Preserve unittest behaviour - fail the test. Explicitly not an XPASS. # Preserve unittest behaviour - fail the test. Explicitly not an XPASS.
try: try:
fail(msg, pytrace=False) fail(msg, pytrace=False)
except fail.Exception: except fail.Exception:
self._addexcinfo(sys.exc_info()) self._addexcinfo(sys.exc_info())
def addSuccess(self, testcase: "unittest.TestCase") -> None: def addSuccess(self, testcase: unittest.TestCase) -> None:
pass pass
def stopTest(self, testcase: "unittest.TestCase") -> None: def stopTest(self, testcase: unittest.TestCase) -> None:
pass pass
def addDuration(self, testcase: "unittest.TestCase", elapsed: float) -> None: def addDuration(self, testcase: unittest.TestCase, elapsed: float) -> None:
pass pass
def runtest(self) -> None: def runtest(self) -> None:
@ -310,9 +312,9 @@ class TestCaseFunction(Function):
# We need to consider if the test itself is skipped, or the whole class. # We need to consider if the test itself is skipped, or the whole class.
assert isinstance(self.parent, UnitTestCase) assert isinstance(self.parent, UnitTestCase)
skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj) skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj)
if self.config.getoption("usepdb") and not skipped: if self.config.getoption('usepdb') and not skipped:
self._explicit_tearDown = self._testcase.tearDown self._explicit_tearDown = self._testcase.tearDown
setattr(self._testcase, "tearDown", lambda *args: None) setattr(self._testcase, 'tearDown', lambda *args: None)
# We need to update the actual bound method with self.obj, because # We need to update the actual bound method with self.obj, because
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper. # wrap_pytest_function_for_tracing replaces self.obj by a wrapper.
@ -323,11 +325,11 @@ class TestCaseFunction(Function):
delattr(self._testcase, self.name) delattr(self._testcase, self.name)
def _traceback_filter( def _traceback_filter(
self, excinfo: _pytest._code.ExceptionInfo[BaseException] self, excinfo: _pytest._code.ExceptionInfo[BaseException],
) -> _pytest._code.Traceback: ) -> _pytest._code.Traceback:
traceback = super()._traceback_filter(excinfo) traceback = super()._traceback_filter(excinfo)
ntraceback = traceback.filter( ntraceback = traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest"), lambda x: not x.frame.f_globals.get('__unittest'),
) )
if not ntraceback: if not ntraceback:
ntraceback = traceback ntraceback = traceback
@ -348,13 +350,13 @@ def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
# This is actually only needed for nose, which reuses unittest.SkipTest for # This is actually only needed for nose, which reuses unittest.SkipTest for
# its own nose.SkipTest. For unittest TestCases, SkipTest is already # its own nose.SkipTest. For unittest TestCases, SkipTest is already
# handled internally, and doesn't reach here. # handled internally, and doesn't reach here.
unittest = sys.modules.get("unittest") unittest = sys.modules.get('unittest')
if ( if (
unittest and call.excinfo and isinstance(call.excinfo.value, unittest.SkipTest) # type: ignore[attr-defined] unittest and call.excinfo and isinstance(call.excinfo.value, unittest.SkipTest) # type: ignore[attr-defined]
): ):
excinfo = call.excinfo excinfo = call.excinfo
call2 = CallInfo[None].from_call( call2 = CallInfo[None].from_call(
lambda: pytest.skip(str(excinfo.value)), call.when lambda: pytest.skip(str(excinfo.value)), call.when,
) )
call.excinfo = call2.excinfo call.excinfo = call2.excinfo
@ -365,8 +367,8 @@ classImplements_has_run = False
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]: def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
if isinstance(item, TestCaseFunction) and "twisted.trial.unittest" in sys.modules: if isinstance(item, TestCaseFunction) and 'twisted.trial.unittest' in sys.modules:
ut: Any = sys.modules["twisted.python.failure"] ut: Any = sys.modules['twisted.python.failure']
global classImplements_has_run global classImplements_has_run
Failure__init__ = ut.Failure.__init__ Failure__init__ = ut.Failure.__init__
if not classImplements_has_run: if not classImplements_has_run:
@ -377,7 +379,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
classImplements_has_run = True classImplements_has_run = True
def excstore( def excstore(
self, exc_value=None, exc_type=None, exc_tb=None, captureVars=None self, exc_value=None, exc_type=None, exc_tb=None, captureVars=None,
): ):
if exc_value is None: if exc_value is None:
self._rawexcinfo = sys.exc_info() self._rawexcinfo = sys.exc_info()
@ -387,7 +389,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
self._rawexcinfo = (exc_type, exc_value, exc_tb) self._rawexcinfo = (exc_type, exc_value, exc_tb)
try: try:
Failure__init__( Failure__init__(
self, exc_value, exc_type, exc_tb, captureVars=captureVars self, exc_value, exc_type, exc_tb, captureVars=captureVars,
) )
except TypeError: except TypeError:
Failure__init__(self, exc_value, exc_type, exc_tb) Failure__init__(self, exc_value, exc_type, exc_tb)
@ -404,4 +406,4 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
def _is_skipped(obj) -> bool: def _is_skipped(obj) -> bool:
"""Return True if the given object has been marked with @unittest.skip.""" """Return True if the given object has been marked with @unittest.skip."""
return bool(getattr(obj, "__unittest_skip__", False)) return bool(getattr(obj, '__unittest_skip__', False))

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import sys import sys
import traceback import traceback
import warnings
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Generator from typing import Generator
from typing import Optional from typing import Optional
from typing import Type from typing import Type
import warnings
import pytest import pytest
@ -34,24 +36,24 @@ class catch_unraisable_exception:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.unraisable: Optional["sys.UnraisableHookArgs"] = None self.unraisable: sys.UnraisableHookArgs | None = None
self._old_hook: Optional[Callable[["sys.UnraisableHookArgs"], Any]] = None self._old_hook: Callable[[sys.UnraisableHookArgs], Any] | None = None
def _hook(self, unraisable: "sys.UnraisableHookArgs") -> None: def _hook(self, unraisable: sys.UnraisableHookArgs) -> None:
# Storing unraisable.object can resurrect an object which is being # Storing unraisable.object can resurrect an object which is being
# finalized. Storing unraisable.exc_value creates a reference cycle. # finalized. Storing unraisable.exc_value creates a reference cycle.
self.unraisable = unraisable self.unraisable = unraisable
def __enter__(self) -> "catch_unraisable_exception": def __enter__(self) -> catch_unraisable_exception:
self._old_hook = sys.unraisablehook self._old_hook = sys.unraisablehook
sys.unraisablehook = self._hook sys.unraisablehook = self._hook
return self return self
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
assert self._old_hook is not None assert self._old_hook is not None
sys.unraisablehook = self._old_hook sys.unraisablehook = self._old_hook
@ -68,14 +70,14 @@ def unraisable_exception_runtest_hook() -> Generator[None, None, None]:
if cm.unraisable.err_msg is not None: if cm.unraisable.err_msg is not None:
err_msg = cm.unraisable.err_msg err_msg = cm.unraisable.err_msg
else: else:
err_msg = "Exception ignored in" err_msg = 'Exception ignored in'
msg = f"{err_msg}: {cm.unraisable.object!r}\n\n" msg = f'{err_msg}: {cm.unraisable.object!r}\n\n'
msg += "".join( msg += ''.join(
traceback.format_exception( traceback.format_exception(
cm.unraisable.exc_type, cm.unraisable.exc_type,
cm.unraisable.exc_value, cm.unraisable.exc_value,
cm.unraisable.exc_traceback, cm.unraisable.exc_traceback,
) ),
) )
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg)) warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))

View file

@ -1,64 +1,66 @@
from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import warnings
from types import FunctionType from types import FunctionType
from typing import Any from typing import Any
from typing import final from typing import final
from typing import Generic from typing import Generic
from typing import Type from typing import Type
from typing import TypeVar from typing import TypeVar
import warnings
class PytestWarning(UserWarning): class PytestWarning(UserWarning):
"""Base class for all warnings emitted by pytest.""" """Base class for all warnings emitted by pytest."""
__module__ = "pytest" __module__ = 'pytest'
@final @final
class PytestAssertRewriteWarning(PytestWarning): class PytestAssertRewriteWarning(PytestWarning):
"""Warning emitted by the pytest assert rewrite module.""" """Warning emitted by the pytest assert rewrite module."""
__module__ = "pytest" __module__ = 'pytest'
@final @final
class PytestCacheWarning(PytestWarning): class PytestCacheWarning(PytestWarning):
"""Warning emitted by the cache plugin in various situations.""" """Warning emitted by the cache plugin in various situations."""
__module__ = "pytest" __module__ = 'pytest'
@final @final
class PytestConfigWarning(PytestWarning): class PytestConfigWarning(PytestWarning):
"""Warning emitted for configuration issues.""" """Warning emitted for configuration issues."""
__module__ = "pytest" __module__ = 'pytest'
@final @final
class PytestCollectionWarning(PytestWarning): class PytestCollectionWarning(PytestWarning):
"""Warning emitted when pytest is not able to collect a file or symbol in a module.""" """Warning emitted when pytest is not able to collect a file or symbol in a module."""
__module__ = "pytest" __module__ = 'pytest'
class PytestDeprecationWarning(PytestWarning, DeprecationWarning): class PytestDeprecationWarning(PytestWarning, DeprecationWarning):
"""Warning class for features that will be removed in a future version.""" """Warning class for features that will be removed in a future version."""
__module__ = "pytest" __module__ = 'pytest'
class PytestRemovedIn9Warning(PytestDeprecationWarning): class PytestRemovedIn9Warning(PytestDeprecationWarning):
"""Warning class for features that will be removed in pytest 9.""" """Warning class for features that will be removed in pytest 9."""
__module__ = "pytest" __module__ = 'pytest'
class PytestReturnNotNoneWarning(PytestWarning): class PytestReturnNotNoneWarning(PytestWarning):
"""Warning emitted when a test function is returning value other than None.""" """Warning emitted when a test function is returning value other than None."""
__module__ = "pytest" __module__ = 'pytest'
@final @final
@ -69,11 +71,11 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
future version. future version.
""" """
__module__ = "pytest" __module__ = 'pytest'
@classmethod @classmethod
def simple(cls, apiname: str) -> "PytestExperimentalApiWarning": def simple(cls, apiname: str) -> PytestExperimentalApiWarning:
return cls(f"{apiname} is an experimental api that may change over time") return cls(f'{apiname} is an experimental api that may change over time')
@final @final
@ -85,7 +87,7 @@ class PytestUnhandledCoroutineWarning(PytestReturnNotNoneWarning):
Coroutine test functions are not natively supported. Coroutine test functions are not natively supported.
""" """
__module__ = "pytest" __module__ = 'pytest'
@final @final
@ -95,7 +97,7 @@ class PytestUnknownMarkWarning(PytestWarning):
See :ref:`mark` for details. See :ref:`mark` for details.
""" """
__module__ = "pytest" __module__ = 'pytest'
@final @final
@ -107,7 +109,7 @@ class PytestUnraisableExceptionWarning(PytestWarning):
as normal. as normal.
""" """
__module__ = "pytest" __module__ = 'pytest'
@final @final
@ -117,10 +119,10 @@ class PytestUnhandledThreadExceptionWarning(PytestWarning):
Such exceptions don't propagate normally. Such exceptions don't propagate normally.
""" """
__module__ = "pytest" __module__ = 'pytest'
_W = TypeVar("_W", bound=PytestWarning) _W = TypeVar('_W', bound=PytestWarning)
@final @final
@ -132,7 +134,7 @@ class UnformattedWarning(Generic[_W]):
as opposed to a direct message. as opposed to a direct message.
""" """
category: Type["_W"] category: type[_W]
template: str template: str
def format(self, **kwargs: Any) -> _W: def format(self, **kwargs: Any) -> _W:
@ -157,9 +159,9 @@ def warn_explicit_for(method: FunctionType, message: PytestWarning) -> None:
type(message), type(message),
filename=filename, filename=filename,
module=module, module=module,
registry=mod_globals.setdefault("__warningregistry__", {}), registry=mod_globals.setdefault('__warningregistry__', {}),
lineno=lineno, lineno=lineno,
) )
except Warning as w: except Warning as w:
# If warnings are errors (e.g. -Werror), location information gets lost, so we add it to the message. # If warnings are errors (e.g. -Werror), location information gets lost, so we add it to the message.
raise type(w)(f"{w}\n at {filename}:{lineno}") from None raise type(w)(f'{w}\n at {filename}:{lineno}') from None

View file

@ -1,25 +1,27 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from contextlib import contextmanager from __future__ import annotations
import sys import sys
import warnings
from contextlib import contextmanager
from typing import Generator from typing import Generator
from typing import Literal from typing import Literal
from typing import Optional from typing import Optional
import warnings
import pytest
from _pytest.config import apply_warning_filters from _pytest.config import apply_warning_filters
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import parse_warning_filter from _pytest.config import parse_warning_filter
from _pytest.main import Session from _pytest.main import Session
from _pytest.nodes import Item from _pytest.nodes import Item
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
import pytest
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
config.addinivalue_line( config.addinivalue_line(
"markers", 'markers',
"filterwarnings(warning): add a warning filter to the given test. " 'filterwarnings(warning): add a warning filter to the given test. '
"see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#pytest-mark-filterwarnings ", 'see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#pytest-mark-filterwarnings ',
) )
@ -27,8 +29,8 @@ def pytest_configure(config: Config) -> None:
def catch_warnings_for_item( def catch_warnings_for_item(
config: Config, config: Config,
ihook, ihook,
when: Literal["config", "collect", "runtest"], when: Literal['config', 'collect', 'runtest'],
item: Optional[Item], item: Item | None,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
"""Context manager that catches warnings generated in the contained execution block. """Context manager that catches warnings generated in the contained execution block.
@ -36,7 +38,7 @@ def catch_warnings_for_item(
Each warning captured triggers the ``pytest_warning_recorded`` hook. Each warning captured triggers the ``pytest_warning_recorded`` hook.
""" """
config_filters = config.getini("filterwarnings") config_filters = config.getini('filterwarnings')
cmdline_filters = config.known_args_namespace.pythonwarnings or [] cmdline_filters = config.known_args_namespace.pythonwarnings or []
with warnings.catch_warnings(record=True) as log: with warnings.catch_warnings(record=True) as log:
# mypy can't infer that record=True means log is not None; help it. # mypy can't infer that record=True means log is not None; help it.
@ -44,8 +46,8 @@ def catch_warnings_for_item(
if not sys.warnoptions: if not sys.warnoptions:
# If user is not explicitly configuring warning filters, show deprecation warnings by default (#2908). # If user is not explicitly configuring warning filters, show deprecation warnings by default (#2908).
warnings.filterwarnings("always", category=DeprecationWarning) warnings.filterwarnings('always', category=DeprecationWarning)
warnings.filterwarnings("always", category=PendingDeprecationWarning) warnings.filterwarnings('always', category=PendingDeprecationWarning)
# To be enabled in pytest 9.0.0. # To be enabled in pytest 9.0.0.
# warnings.filterwarnings("error", category=pytest.PytestRemovedIn9Warning) # warnings.filterwarnings("error", category=pytest.PytestRemovedIn9Warning)
@ -53,9 +55,9 @@ def catch_warnings_for_item(
apply_warning_filters(config_filters, cmdline_filters) apply_warning_filters(config_filters, cmdline_filters)
# apply filters from "filterwarnings" marks # apply filters from "filterwarnings" marks
nodeid = "" if item is None else item.nodeid nodeid = '' if item is None else item.nodeid
if item is not None: if item is not None:
for mark in item.iter_markers(name="filterwarnings"): for mark in item.iter_markers(name='filterwarnings'):
for arg in mark.args: for arg in mark.args:
warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
@ -69,7 +71,7 @@ def catch_warnings_for_item(
nodeid=nodeid, nodeid=nodeid,
when=when, when=when,
location=None, location=None,
) ),
) )
@ -91,22 +93,22 @@ def warning_record_to_str(warning_message: warnings.WarningMessage) -> str:
else: else:
tb = tracemalloc.get_object_traceback(warning_message.source) tb = tracemalloc.get_object_traceback(warning_message.source)
if tb is not None: if tb is not None:
formatted_tb = "\n".join(tb.format()) formatted_tb = '\n'.join(tb.format())
# Use a leading new line to better separate the (large) output # Use a leading new line to better separate the (large) output
# from the traceback to the previous warning text. # from the traceback to the previous warning text.
msg += f"\nObject allocated at:\n{formatted_tb}" msg += f'\nObject allocated at:\n{formatted_tb}'
else: else:
# No need for a leading new line. # No need for a leading new line.
url = "https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings" url = 'https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings'
msg += "Enable tracemalloc to get traceback where the object was allocated.\n" msg += 'Enable tracemalloc to get traceback where the object was allocated.\n'
msg += f"See {url} for more info." msg += f'See {url} for more info.'
return msg return msg
@pytest.hookimpl(wrapper=True, tryfirst=True) @pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]: def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
with catch_warnings_for_item( with catch_warnings_for_item(
config=item.config, ihook=item.ihook, when="runtest", item=item config=item.config, ihook=item.ihook, when='runtest', item=item,
): ):
return (yield) return (yield)
@ -115,7 +117,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
def pytest_collection(session: Session) -> Generator[None, object, object]: def pytest_collection(session: Session) -> Generator[None, object, object]:
config = session.config config = session.config
with catch_warnings_for_item( with catch_warnings_for_item(
config=config, ihook=config.hook, when="collect", item=None config=config, ihook=config.hook, when='collect', item=None,
): ):
return (yield) return (yield)
@ -126,7 +128,7 @@ def pytest_terminal_summary(
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
config = terminalreporter.config config = terminalreporter.config
with catch_warnings_for_item( with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None config=config, ihook=config.hook, when='config', item=None,
): ):
return (yield) return (yield)
@ -135,16 +137,16 @@ def pytest_terminal_summary(
def pytest_sessionfinish(session: Session) -> Generator[None, None, None]: def pytest_sessionfinish(session: Session) -> Generator[None, None, None]:
config = session.config config = session.config
with catch_warnings_for_item( with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None config=config, ihook=config.hook, when='config', item=None,
): ):
return (yield) return (yield)
@pytest.hookimpl(wrapper=True) @pytest.hookimpl(wrapper=True)
def pytest_load_initial_conftests( def pytest_load_initial_conftests(
early_config: "Config", early_config: Config,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
with catch_warnings_for_item( with catch_warnings_for_item(
config=early_config, ihook=early_config.hook, when="config", item=None config=early_config, ihook=early_config.hook, when='config', item=None,
): ):
return (yield) return (yield)

View file

@ -3,4 +3,3 @@ Generator: bdist_wheel (0.37.1)
Root-Is-Purelib: true Root-Is-Purelib: true
Tag: py2-none-any Tag: py2-none-any
Tag: py3-none-any Tag: py3-none-any

View file

@ -2,4 +2,3 @@ Wheel-Version: 1.0
Generator: bdist_wheel (0.43.0) Generator: bdist_wheel (0.43.0)
Root-Is-Purelib: false Root-Is-Purelib: false
Tag: cp310-cp310-macosx_10_9_x86_64 Tag: cp310-cp310-macosx_10_9_x86_64

View file

@ -1,6 +1,5 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
""" """
Code coverage measurement for Python. Code coverage measurement for Python.
@ -8,31 +7,22 @@ Ned Batchelder
https://coverage.readthedocs.io https://coverage.readthedocs.io
""" """
from __future__ import annotations from __future__ import annotations
from coverage.control import Coverage as Coverage
from coverage.control import process_startup as process_startup
from coverage.data import CoverageData as CoverageData
from coverage.exceptions import CoverageException as CoverageException
from coverage.plugin import CoveragePlugin as CoveragePlugin
from coverage.plugin import FileReporter as FileReporter
from coverage.plugin import FileTracer as FileTracer
from coverage.version import __version__ as __version__
from coverage.version import version_info as version_info
# mypy's convention is that "import as" names are public from the module. # mypy's convention is that "import as" names are public from the module.
# We import names as themselves to indicate that. Pylint sees it as pointless, # We import names as themselves to indicate that. Pylint sees it as pointless,
# so disable its warning. # so disable its warning.
# pylint: disable=useless-import-alias # pylint: disable=useless-import-alias
from coverage.version import (
__version__ as __version__,
version_info as version_info,
)
from coverage.control import (
Coverage as Coverage,
process_startup as process_startup,
)
from coverage.data import CoverageData as CoverageData
from coverage.exceptions import CoverageException as CoverageException
from coverage.plugin import (
CoveragePlugin as CoveragePlugin,
FileReporter as FileReporter,
FileTracer as FileTracer,
)
# Backward compatibility. # Backward compatibility.
coverage = Coverage coverage = Coverage

View file

@ -1,10 +1,9 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Coverage.py's main entry point.""" """Coverage.py's main entry point."""
from __future__ import annotations from __future__ import annotations
import sys import sys
from coverage.cmdline import main from coverage.cmdline import main
sys.exit(main()) sys.exit(main())

View file

@ -1,17 +1,16 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Source file annotation for coverage.py.""" """Source file annotation for coverage.py."""
from __future__ import annotations from __future__ import annotations
import os import os
import re import re
from typing import Iterable
from typing import Iterable, TYPE_CHECKING from typing import TYPE_CHECKING
from coverage.files import flat_rootname from coverage.files import flat_rootname
from coverage.misc import ensure_dir, isolate_module from coverage.misc import ensure_dir
from coverage.misc import isolate_module
from coverage.plugin import FileReporter from coverage.plugin import FileReporter
from coverage.report_core import get_analysis_to_report from coverage.report_core import get_analysis_to_report
from coverage.results import Analysis from coverage.results import Analysis
@ -50,8 +49,8 @@ class AnnotateReporter:
self.config = self.coverage.config self.config = self.coverage.config
self.directory: str | None = None self.directory: str | None = None
blank_re = re.compile(r"\s*(#|$)") blank_re = re.compile(r'\s*(#|$)')
else_re = re.compile(r"\s*else\s*:\s*(#|$)") else_re = re.compile(r'\s*else\s*:\s*(#|$)')
def report(self, morfs: Iterable[TMorf] | None, directory: str | None = None) -> None: def report(self, morfs: Iterable[TMorf] | None, directory: str | None = None) -> None:
"""Run the report. """Run the report.
@ -77,13 +76,13 @@ class AnnotateReporter:
if self.directory: if self.directory:
ensure_dir(self.directory) ensure_dir(self.directory)
dest_file = os.path.join(self.directory, flat_rootname(fr.relative_filename())) dest_file = os.path.join(self.directory, flat_rootname(fr.relative_filename()))
if dest_file.endswith("_py"): if dest_file.endswith('_py'):
dest_file = dest_file[:-3] + ".py" dest_file = dest_file[:-3] + '.py'
dest_file += ",cover" dest_file += ',cover'
else: else:
dest_file = fr.filename + ",cover" dest_file = fr.filename + ',cover'
with open(dest_file, "w", encoding="utf-8") as dest: with open(dest_file, 'w', encoding='utf-8') as dest:
i = j = 0 i = j = 0
covered = True covered = True
source = fr.source() source = fr.source()
@ -95,20 +94,20 @@ class AnnotateReporter:
if i < len(statements) and statements[i] == lineno: if i < len(statements) and statements[i] == lineno:
covered = j >= len(missing) or missing[j] > lineno covered = j >= len(missing) or missing[j] > lineno
if self.blank_re.match(line): if self.blank_re.match(line):
dest.write(" ") dest.write(' ')
elif self.else_re.match(line): elif self.else_re.match(line):
# Special logic for lines containing only "else:". # Special logic for lines containing only "else:".
if j >= len(missing): if j >= len(missing):
dest.write("> ") dest.write('> ')
elif statements[i] == missing[j]: elif statements[i] == missing[j]:
dest.write("! ") dest.write('! ')
else: else:
dest.write("> ") dest.write('> ')
elif lineno in excluded: elif lineno in excluded:
dest.write("- ") dest.write('- ')
elif covered: elif covered:
dest.write("> ") dest.write('> ')
else: else:
dest.write("! ") dest.write('! ')
dest.write(line) dest.write(line)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Bytecode manipulation for coverage.py""" """Bytecode manipulation for coverage.py"""
from __future__ import annotations from __future__ import annotations
from types import CodeType from types import CodeType

View file

@ -1,20 +1,18 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Command-line support for coverage.py.""" """Command-line support for coverage.py."""
from __future__ import annotations from __future__ import annotations
import glob import glob
import optparse # pylint: disable=deprecated-module import optparse # pylint: disable=deprecated-module
import os
import os.path import os.path
import shlex import shlex
import sys import sys
import textwrap import textwrap
import traceback import traceback
from typing import Any
from typing import cast, Any, NoReturn from typing import cast
from typing import NoReturn
import coverage import coverage
from coverage import Coverage from coverage import Coverage
@ -22,16 +20,23 @@ from coverage import env
from coverage.collector import HAS_CTRACER from coverage.collector import HAS_CTRACER
from coverage.config import CoverageConfig from coverage.config import CoverageConfig
from coverage.control import DEFAULT_DATAFILE from coverage.control import DEFAULT_DATAFILE
from coverage.data import combinable_files, debug_data_file from coverage.data import combinable_files
from coverage.debug import info_header, short_stack, write_formatted_info from coverage.data import debug_data_file
from coverage.exceptions import _BaseCoverageException, _ExceptionDuringRun, NoSource from coverage.debug import info_header
from coverage.debug import short_stack
from coverage.debug import write_formatted_info
from coverage.exceptions import _BaseCoverageException
from coverage.exceptions import _ExceptionDuringRun
from coverage.exceptions import NoSource
from coverage.execfile import PyRunner from coverage.execfile import PyRunner
from coverage.results import Numbers, should_fail_under from coverage.results import Numbers
from coverage.results import should_fail_under
from coverage.version import __url__ from coverage.version import __url__
# When adding to this file, alphabetization is important. Look for # When adding to this file, alphabetization is important. Look for
# "alphabetize" comments throughout. # "alphabetize" comments throughout.
class Opts: class Opts:
"""A namespace class for individual options we'll build parsers from.""" """A namespace class for individual options we'll build parsers from."""
@ -39,193 +44,193 @@ class Opts:
# appears on the command line. # appears on the command line.
append = optparse.make_option( append = optparse.make_option(
"-a", "--append", action="store_true", '-a', '--append', action='store_true',
help="Append coverage data to .coverage, otherwise it starts clean each time.", help='Append coverage data to .coverage, otherwise it starts clean each time.',
) )
branch = optparse.make_option( branch = optparse.make_option(
"", "--branch", action="store_true", '', '--branch', action='store_true',
help="Measure branch coverage in addition to statement coverage.", help='Measure branch coverage in addition to statement coverage.',
) )
concurrency = optparse.make_option( concurrency = optparse.make_option(
"", "--concurrency", action="store", metavar="LIBS", '', '--concurrency', action='store', metavar='LIBS',
help=( help=(
"Properly measure code using a concurrency library. " + 'Properly measure code using a concurrency library. ' +
"Valid values are: {}, or a comma-list of them." 'Valid values are: {}, or a comma-list of them.'
).format(", ".join(sorted(CoverageConfig.CONCURRENCY_CHOICES))), ).format(', '.join(sorted(CoverageConfig.CONCURRENCY_CHOICES))),
) )
context = optparse.make_option( context = optparse.make_option(
"", "--context", action="store", metavar="LABEL", '', '--context', action='store', metavar='LABEL',
help="The context label to record for this coverage run.", help='The context label to record for this coverage run.',
) )
contexts = optparse.make_option( contexts = optparse.make_option(
"", "--contexts", action="store", metavar="REGEX1,REGEX2,...", '', '--contexts', action='store', metavar='REGEX1,REGEX2,...',
help=( help=(
"Only display data from lines covered in the given contexts. " + 'Only display data from lines covered in the given contexts. ' +
"Accepts Python regexes, which must be quoted." 'Accepts Python regexes, which must be quoted.'
), ),
) )
datafile = optparse.make_option( datafile = optparse.make_option(
"", "--data-file", action="store", metavar="DATAFILE", '', '--data-file', action='store', metavar='DATAFILE',
help=( help=(
"Base name of the data files to operate on. " + 'Base name of the data files to operate on. ' +
"Defaults to '.coverage'. [env: COVERAGE_FILE]" "Defaults to '.coverage'. [env: COVERAGE_FILE]"
), ),
) )
datafle_input = optparse.make_option( datafle_input = optparse.make_option(
"", "--data-file", action="store", metavar="INFILE", '', '--data-file', action='store', metavar='INFILE',
help=( help=(
"Read coverage data for report generation from this file. " + 'Read coverage data for report generation from this file. ' +
"Defaults to '.coverage'. [env: COVERAGE_FILE]" "Defaults to '.coverage'. [env: COVERAGE_FILE]"
), ),
) )
datafile_output = optparse.make_option( datafile_output = optparse.make_option(
"", "--data-file", action="store", metavar="OUTFILE", '', '--data-file', action='store', metavar='OUTFILE',
help=( help=(
"Write the recorded coverage data to this file. " + 'Write the recorded coverage data to this file. ' +
"Defaults to '.coverage'. [env: COVERAGE_FILE]" "Defaults to '.coverage'. [env: COVERAGE_FILE]"
), ),
) )
debug = optparse.make_option( debug = optparse.make_option(
"", "--debug", action="store", metavar="OPTS", '', '--debug', action='store', metavar='OPTS',
help="Debug options, separated by commas. [env: COVERAGE_DEBUG]", help='Debug options, separated by commas. [env: COVERAGE_DEBUG]',
) )
directory = optparse.make_option( directory = optparse.make_option(
"-d", "--directory", action="store", metavar="DIR", '-d', '--directory', action='store', metavar='DIR',
help="Write the output files to DIR.", help='Write the output files to DIR.',
) )
fail_under = optparse.make_option( fail_under = optparse.make_option(
"", "--fail-under", action="store", metavar="MIN", type="float", '', '--fail-under', action='store', metavar='MIN', type='float',
help="Exit with a status of 2 if the total coverage is less than MIN.", help='Exit with a status of 2 if the total coverage is less than MIN.',
) )
format = optparse.make_option( format = optparse.make_option(
"", "--format", action="store", metavar="FORMAT", '', '--format', action='store', metavar='FORMAT',
help="Output format, either text (default), markdown, or total.", help='Output format, either text (default), markdown, or total.',
) )
help = optparse.make_option( help = optparse.make_option(
"-h", "--help", action="store_true", '-h', '--help', action='store_true',
help="Get help on this command.", help='Get help on this command.',
) )
ignore_errors = optparse.make_option( ignore_errors = optparse.make_option(
"-i", "--ignore-errors", action="store_true", '-i', '--ignore-errors', action='store_true',
help="Ignore errors while reading source files.", help='Ignore errors while reading source files.',
) )
include = optparse.make_option( include = optparse.make_option(
"", "--include", action="store", metavar="PAT1,PAT2,...", '', '--include', action='store', metavar='PAT1,PAT2,...',
help=( help=(
"Include only files whose paths match one of these patterns. " + 'Include only files whose paths match one of these patterns. ' +
"Accepts shell-style wildcards, which must be quoted." 'Accepts shell-style wildcards, which must be quoted.'
), ),
) )
keep = optparse.make_option( keep = optparse.make_option(
"", "--keep", action="store_true", '', '--keep', action='store_true',
help="Keep original coverage files, otherwise they are deleted.", help='Keep original coverage files, otherwise they are deleted.',
) )
pylib = optparse.make_option( pylib = optparse.make_option(
"-L", "--pylib", action="store_true", '-L', '--pylib', action='store_true',
help=( help=(
"Measure coverage even inside the Python installed library, " + 'Measure coverage even inside the Python installed library, ' +
"which isn't done by default." "which isn't done by default."
), ),
) )
show_missing = optparse.make_option( show_missing = optparse.make_option(
"-m", "--show-missing", action="store_true", '-m', '--show-missing', action='store_true',
help="Show line numbers of statements in each module that weren't executed.", help="Show line numbers of statements in each module that weren't executed.",
) )
module = optparse.make_option( module = optparse.make_option(
"-m", "--module", action="store_true", '-m', '--module', action='store_true',
help=( help=(
"<pyfile> is an importable Python module, not a script path, " + '<pyfile> is an importable Python module, not a script path, ' +
"to be run as 'python -m' would run it." "to be run as 'python -m' would run it."
), ),
) )
omit = optparse.make_option( omit = optparse.make_option(
"", "--omit", action="store", metavar="PAT1,PAT2,...", '', '--omit', action='store', metavar='PAT1,PAT2,...',
help=( help=(
"Omit files whose paths match one of these patterns. " + 'Omit files whose paths match one of these patterns. ' +
"Accepts shell-style wildcards, which must be quoted." 'Accepts shell-style wildcards, which must be quoted.'
), ),
) )
output_xml = optparse.make_option( output_xml = optparse.make_option(
"-o", "", action="store", dest="outfile", metavar="OUTFILE", '-o', '', action='store', dest='outfile', metavar='OUTFILE',
help="Write the XML report to this file. Defaults to 'coverage.xml'", help="Write the XML report to this file. Defaults to 'coverage.xml'",
) )
output_json = optparse.make_option( output_json = optparse.make_option(
"-o", "", action="store", dest="outfile", metavar="OUTFILE", '-o', '', action='store', dest='outfile', metavar='OUTFILE',
help="Write the JSON report to this file. Defaults to 'coverage.json'", help="Write the JSON report to this file. Defaults to 'coverage.json'",
) )
output_lcov = optparse.make_option( output_lcov = optparse.make_option(
"-o", "", action="store", dest="outfile", metavar="OUTFILE", '-o', '', action='store', dest='outfile', metavar='OUTFILE',
help="Write the LCOV report to this file. Defaults to 'coverage.lcov'", help="Write the LCOV report to this file. Defaults to 'coverage.lcov'",
) )
json_pretty_print = optparse.make_option( json_pretty_print = optparse.make_option(
"", "--pretty-print", action="store_true", '', '--pretty-print', action='store_true',
help="Format the JSON for human readers.", help='Format the JSON for human readers.',
) )
parallel_mode = optparse.make_option( parallel_mode = optparse.make_option(
"-p", "--parallel-mode", action="store_true", '-p', '--parallel-mode', action='store_true',
help=( help=(
"Append the machine name, process id and random number to the " + 'Append the machine name, process id and random number to the ' +
"data file name to simplify collecting data from " + 'data file name to simplify collecting data from ' +
"many processes." 'many processes.'
), ),
) )
precision = optparse.make_option( precision = optparse.make_option(
"", "--precision", action="store", metavar="N", type=int, '', '--precision', action='store', metavar='N', type=int,
help=( help=(
"Number of digits after the decimal point to display for " + 'Number of digits after the decimal point to display for ' +
"reported coverage percentages." 'reported coverage percentages.'
), ),
) )
quiet = optparse.make_option( quiet = optparse.make_option(
"-q", "--quiet", action="store_true", '-q', '--quiet', action='store_true',
help="Don't print messages about what is happening.", help="Don't print messages about what is happening.",
) )
rcfile = optparse.make_option( rcfile = optparse.make_option(
"", "--rcfile", action="store", '', '--rcfile', action='store',
help=( help=(
"Specify configuration file. " + 'Specify configuration file. ' +
"By default '.coveragerc', 'setup.cfg', 'tox.ini', and " + "By default '.coveragerc', 'setup.cfg', 'tox.ini', and " +
"'pyproject.toml' are tried. [env: COVERAGE_RCFILE]" "'pyproject.toml' are tried. [env: COVERAGE_RCFILE]"
), ),
) )
show_contexts = optparse.make_option( show_contexts = optparse.make_option(
"--show-contexts", action="store_true", '--show-contexts', action='store_true',
help="Show contexts for covered lines.", help='Show contexts for covered lines.',
) )
skip_covered = optparse.make_option( skip_covered = optparse.make_option(
"--skip-covered", action="store_true", '--skip-covered', action='store_true',
help="Skip files with 100% coverage.", help='Skip files with 100% coverage.',
) )
no_skip_covered = optparse.make_option( no_skip_covered = optparse.make_option(
"--no-skip-covered", action="store_false", dest="skip_covered", '--no-skip-covered', action='store_false', dest='skip_covered',
help="Disable --skip-covered.", help='Disable --skip-covered.',
) )
skip_empty = optparse.make_option( skip_empty = optparse.make_option(
"--skip-empty", action="store_true", '--skip-empty', action='store_true',
help="Skip files with no code.", help='Skip files with no code.',
) )
sort = optparse.make_option( sort = optparse.make_option(
"--sort", action="store", metavar="COLUMN", '--sort', action='store', metavar='COLUMN',
help=( help=(
"Sort the report by the named column: name, stmts, miss, branch, brpart, or cover. " + 'Sort the report by the named column: name, stmts, miss, branch, brpart, or cover. ' +
"Default is name." 'Default is name.'
), ),
) )
source = optparse.make_option( source = optparse.make_option(
"", "--source", action="store", metavar="SRC1,SRC2,...", '', '--source', action='store', metavar='SRC1,SRC2,...',
help="A list of directories or importable names of code to measure.", help='A list of directories or importable names of code to measure.',
) )
timid = optparse.make_option( timid = optparse.make_option(
"", "--timid", action="store_true", '', '--timid', action='store_true',
help="Use the slower Python trace function core.", help='Use the slower Python trace function core.',
) )
title = optparse.make_option( title = optparse.make_option(
"", "--title", action="store", metavar="TITLE", '', '--title', action='store', metavar='TITLE',
help="A text string to use as the title on the HTML.", help='A text string to use as the title on the HTML.',
) )
version = optparse.make_option( version = optparse.make_option(
"", "--version", action="store_true", '', '--version', action='store_true',
help="Display version information and exit.", help='Display version information and exit.',
) )
@ -238,7 +243,7 @@ class CoverageOptionParser(optparse.OptionParser):
""" """
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["add_help_option"] = False kwargs['add_help_option'] = False
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.set_defaults( self.set_defaults(
# Keep these arguments alphabetized by their names. # Keep these arguments alphabetized by their names.
@ -330,7 +335,7 @@ class CmdOptionParser(CoverageOptionParser):
""" """
if usage: if usage:
usage = "%prog " + usage usage = '%prog ' + usage
super().__init__( super().__init__(
usage=usage, usage=usage,
description=description, description=description,
@ -342,7 +347,7 @@ class CmdOptionParser(CoverageOptionParser):
def __eq__(self, other: str) -> bool: # type: ignore[override] def __eq__(self, other: str) -> bool: # type: ignore[override]
# A convenience equality, so that I can put strings in unit test # A convenience equality, so that I can put strings in unit test
# results, and they will compare equal to objects. # results, and they will compare equal to objects.
return (other == f"<CmdOptionParser:{self.cmd}>") return (other == f'<CmdOptionParser:{self.cmd}>')
__hash__ = None # type: ignore[assignment] __hash__ = None # type: ignore[assignment]
@ -351,7 +356,7 @@ class CmdOptionParser(CoverageOptionParser):
program_name = super().get_prog_name() program_name = super().get_prog_name()
# Include the sub-command for this parser as part of the command. # Include the sub-command for this parser as part of the command.
return f"{program_name} {self.cmd}" return f'{program_name} {self.cmd}'
# In lists of Opts, keep them alphabetized by the option names as they appear # In lists of Opts, keep them alphabetized by the option names as they appear
# on the command line, since these lists determine the order of the options in # on the command line, since these lists determine the order of the options in
@ -359,6 +364,7 @@ class CmdOptionParser(CoverageOptionParser):
# #
# In COMMANDS, keep the keys (command names) alphabetized. # In COMMANDS, keep the keys (command names) alphabetized.
GLOBAL_ARGS = [ GLOBAL_ARGS = [
Opts.debug, Opts.debug,
Opts.help, Opts.help,
@ -366,72 +372,72 @@ GLOBAL_ARGS = [
] ]
COMMANDS = { COMMANDS = {
"annotate": CmdOptionParser( 'annotate': CmdOptionParser(
"annotate", 'annotate',
[ [
Opts.directory, Opts.directory,
Opts.datafle_input, Opts.datafle_input,
Opts.ignore_errors, Opts.ignore_errors,
Opts.include, Opts.include,
Opts.omit, Opts.omit,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] [modules]", usage='[options] [modules]',
description=( description=(
"Make annotated copies of the given files, marking statements that are executed " + 'Make annotated copies of the given files, marking statements that are executed ' +
"with > and statements that are missed with !." 'with > and statements that are missed with !.'
), ),
), ),
"combine": CmdOptionParser( 'combine': CmdOptionParser(
"combine", 'combine',
[ [
Opts.append, Opts.append,
Opts.datafile, Opts.datafile,
Opts.keep, Opts.keep,
Opts.quiet, Opts.quiet,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] <path1> <path2> ... <pathN>", usage='[options] <path1> <path2> ... <pathN>',
description=( description=(
"Combine data from multiple coverage files. " + 'Combine data from multiple coverage files. ' +
"The combined results are written to a single " + 'The combined results are written to a single ' +
"file representing the union of the data. The positional " + 'file representing the union of the data. The positional ' +
"arguments are data files or directories containing data files. " + 'arguments are data files or directories containing data files. ' +
"If no paths are provided, data files in the default data file's " + "If no paths are provided, data files in the default data file's " +
"directory are combined." 'directory are combined.'
), ),
), ),
"debug": CmdOptionParser( 'debug': CmdOptionParser(
"debug", GLOBAL_ARGS, 'debug', GLOBAL_ARGS,
usage="<topic>", usage='<topic>',
description=( description=(
"Display information about the internals of coverage.py, " + 'Display information about the internals of coverage.py, ' +
"for diagnosing problems. " + 'for diagnosing problems. ' +
"Topics are: " + 'Topics are: ' +
"'data' to show a summary of the collected data; " + "'data' to show a summary of the collected data; " +
"'sys' to show installation information; " + "'sys' to show installation information; " +
"'config' to show the configuration; " + "'config' to show the configuration; " +
"'premain' to show what is calling coverage; " + "'premain' to show what is calling coverage; " +
"'pybehave' to show internal flags describing Python behavior." "'pybehave' to show internal flags describing Python behavior."
), ),
), ),
"erase": CmdOptionParser( 'erase': CmdOptionParser(
"erase", 'erase',
[ [
Opts.datafile, Opts.datafile,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
description="Erase previously collected coverage data.", description='Erase previously collected coverage data.',
), ),
"help": CmdOptionParser( 'help': CmdOptionParser(
"help", GLOBAL_ARGS, 'help', GLOBAL_ARGS,
usage="[command]", usage='[command]',
description="Describe how to use coverage.py", description='Describe how to use coverage.py',
), ),
"html": CmdOptionParser( 'html': CmdOptionParser(
"html", 'html',
[ [
Opts.contexts, Opts.contexts,
Opts.directory, Opts.directory,
@ -447,17 +453,17 @@ COMMANDS = {
Opts.no_skip_covered, Opts.no_skip_covered,
Opts.skip_empty, Opts.skip_empty,
Opts.title, Opts.title,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] [modules]", usage='[options] [modules]',
description=( description=(
"Create an HTML report of the coverage of the files. " + 'Create an HTML report of the coverage of the files. ' +
"Each file gets its own page, with the source decorated to show " + 'Each file gets its own page, with the source decorated to show ' +
"executed, excluded, and missed lines." 'executed, excluded, and missed lines.'
), ),
), ),
"json": CmdOptionParser( 'json': CmdOptionParser(
"json", 'json',
[ [
Opts.contexts, Opts.contexts,
Opts.datafle_input, Opts.datafle_input,
@ -469,13 +475,13 @@ COMMANDS = {
Opts.json_pretty_print, Opts.json_pretty_print,
Opts.quiet, Opts.quiet,
Opts.show_contexts, Opts.show_contexts,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] [modules]", usage='[options] [modules]',
description="Generate a JSON report of coverage results.", description='Generate a JSON report of coverage results.',
), ),
"lcov": CmdOptionParser( 'lcov': CmdOptionParser(
"lcov", 'lcov',
[ [
Opts.datafle_input, Opts.datafle_input,
Opts.fail_under, Opts.fail_under,
@ -484,13 +490,13 @@ COMMANDS = {
Opts.output_lcov, Opts.output_lcov,
Opts.omit, Opts.omit,
Opts.quiet, Opts.quiet,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] [modules]", usage='[options] [modules]',
description="Generate an LCOV report of coverage results.", description='Generate an LCOV report of coverage results.',
), ),
"report": CmdOptionParser( 'report': CmdOptionParser(
"report", 'report',
[ [
Opts.contexts, Opts.contexts,
Opts.datafle_input, Opts.datafle_input,
@ -505,13 +511,13 @@ COMMANDS = {
Opts.skip_covered, Opts.skip_covered,
Opts.no_skip_covered, Opts.no_skip_covered,
Opts.skip_empty, Opts.skip_empty,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] [modules]", usage='[options] [modules]',
description="Report coverage statistics on modules.", description='Report coverage statistics on modules.',
), ),
"run": CmdOptionParser( 'run': CmdOptionParser(
"run", 'run',
[ [
Opts.append, Opts.append,
Opts.branch, Opts.branch,
@ -525,13 +531,13 @@ COMMANDS = {
Opts.parallel_mode, Opts.parallel_mode,
Opts.source, Opts.source,
Opts.timid, Opts.timid,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] <pyfile> [program options]", usage='[options] <pyfile> [program options]',
description="Run a Python program, measuring code execution.", description='Run a Python program, measuring code execution.',
), ),
"xml": CmdOptionParser( 'xml': CmdOptionParser(
"xml", 'xml',
[ [
Opts.datafle_input, Opts.datafle_input,
Opts.fail_under, Opts.fail_under,
@ -541,9 +547,9 @@ COMMANDS = {
Opts.output_xml, Opts.output_xml,
Opts.quiet, Opts.quiet,
Opts.skip_empty, Opts.skip_empty,
] + GLOBAL_ARGS, ] + GLOBAL_ARGS,
usage="[options] [modules]", usage='[options] [modules]',
description="Generate an XML report of coverage results.", description='Generate an XML report of coverage results.',
), ),
} }
@ -557,7 +563,7 @@ def show_help(
assert error or topic or parser assert error or topic or parser
program_path = sys.argv[0] program_path = sys.argv[0]
if program_path.endswith(os.path.sep + "__main__.py"): if program_path.endswith(os.path.sep + '__main__.py'):
# The path is the main module of a package; get that path instead. # The path is the main module of a package; get that path instead.
program_path = os.path.dirname(program_path) program_path = os.path.dirname(program_path)
program_name = os.path.basename(program_path) program_name = os.path.basename(program_path)
@ -567,17 +573,17 @@ def show_help(
# invoke coverage-script.py, coverage3-script.py, and # invoke coverage-script.py, coverage3-script.py, and
# coverage-3.5-script.py. argv[0] is the .py file, but we want to # coverage-3.5-script.py. argv[0] is the .py file, but we want to
# get back to the original form. # get back to the original form.
auto_suffix = "-script.py" auto_suffix = '-script.py'
if program_name.endswith(auto_suffix): if program_name.endswith(auto_suffix):
program_name = program_name[:-len(auto_suffix)] program_name = program_name[:-len(auto_suffix)]
help_params = dict(coverage.__dict__) help_params = dict(coverage.__dict__)
help_params["__url__"] = __url__ help_params['__url__'] = __url__
help_params["program_name"] = program_name help_params['program_name'] = program_name
if HAS_CTRACER: if HAS_CTRACER:
help_params["extension_modifier"] = "with C extension" help_params['extension_modifier'] = 'with C extension'
else: else:
help_params["extension_modifier"] = "without C extension" help_params['extension_modifier'] = 'without C extension'
if error: if error:
print(error, file=sys.stderr) print(error, file=sys.stderr)
@ -587,12 +593,12 @@ def show_help(
print() print()
else: else:
assert topic is not None assert topic is not None
help_msg = textwrap.dedent(HELP_TOPICS.get(topic, "")).strip() help_msg = textwrap.dedent(HELP_TOPICS.get(topic, '')).strip()
if help_msg: if help_msg:
print(help_msg.format(**help_params)) print(help_msg.format(**help_params))
else: else:
print(f"Don't know topic {topic!r}") print(f"Don't know topic {topic!r}")
print("Full documentation is at {__url__}".format(**help_params)) print('Full documentation is at {__url__}'.format(**help_params))
OK, ERR, FAIL_UNDER = 0, 1, 2 OK, ERR, FAIL_UNDER = 0, 1, 2
@ -615,19 +621,19 @@ class CoverageScript:
""" """
# Collect the command-line options. # Collect the command-line options.
if not argv: if not argv:
show_help(topic="minimum_help") show_help(topic='minimum_help')
return OK return OK
# The command syntax we parse depends on the first argument. Global # The command syntax we parse depends on the first argument. Global
# switch syntax always starts with an option. # switch syntax always starts with an option.
parser: optparse.OptionParser | None parser: optparse.OptionParser | None
self.global_option = argv[0].startswith("-") self.global_option = argv[0].startswith('-')
if self.global_option: if self.global_option:
parser = GlobalOptionParser() parser = GlobalOptionParser()
else: else:
parser = COMMANDS.get(argv[0]) parser = COMMANDS.get(argv[0])
if not parser: if not parser:
show_help(f"Unknown command: {argv[0]!r}") show_help(f'Unknown command: {argv[0]!r}')
return ERR return ERR
argv = argv[1:] argv = argv[1:]
@ -648,7 +654,7 @@ class CoverageScript:
contexts = unshell_list(options.contexts) contexts = unshell_list(options.contexts)
if options.concurrency is not None: if options.concurrency is not None:
concurrency = options.concurrency.split(",") concurrency = options.concurrency.split(',')
else: else:
concurrency = None concurrency = None
@ -670,17 +676,17 @@ class CoverageScript:
messages=not options.quiet, messages=not options.quiet,
) )
if options.action == "debug": if options.action == 'debug':
return self.do_debug(args) return self.do_debug(args)
elif options.action == "erase": elif options.action == 'erase':
self.coverage.erase() self.coverage.erase()
return OK return OK
elif options.action == "run": elif options.action == 'run':
return self.do_run(options, args) return self.do_run(options, args)
elif options.action == "combine": elif options.action == 'combine':
if options.append: if options.append:
self.coverage.load() self.coverage.load()
data_paths = args or None data_paths = args or None
@ -699,12 +705,12 @@ class CoverageScript:
# We need to be able to import from the current directory, because # We need to be able to import from the current directory, because
# plugins may try to, for example, to read Django settings. # plugins may try to, for example, to read Django settings.
sys.path.insert(0, "") sys.path.insert(0, '')
self.coverage.load() self.coverage.load()
total = None total = None
if options.action == "report": if options.action == 'report':
total = self.coverage.report( total = self.coverage.report(
precision=options.precision, precision=options.precision,
show_missing=options.show_missing, show_missing=options.show_missing,
@ -714,9 +720,9 @@ class CoverageScript:
output_format=options.format, output_format=options.format,
**report_args, **report_args,
) )
elif options.action == "annotate": elif options.action == 'annotate':
self.coverage.annotate(directory=options.directory, **report_args) self.coverage.annotate(directory=options.directory, **report_args)
elif options.action == "html": elif options.action == 'html':
total = self.coverage.html_report( total = self.coverage.html_report(
directory=options.directory, directory=options.directory,
precision=options.precision, precision=options.precision,
@ -726,20 +732,20 @@ class CoverageScript:
title=options.title, title=options.title,
**report_args, **report_args,
) )
elif options.action == "xml": elif options.action == 'xml':
total = self.coverage.xml_report( total = self.coverage.xml_report(
outfile=options.outfile, outfile=options.outfile,
skip_empty=options.skip_empty, skip_empty=options.skip_empty,
**report_args, **report_args,
) )
elif options.action == "json": elif options.action == 'json':
total = self.coverage.json_report( total = self.coverage.json_report(
outfile=options.outfile, outfile=options.outfile,
pretty_print=options.pretty_print, pretty_print=options.pretty_print,
show_contexts=options.show_contexts, show_contexts=options.show_contexts,
**report_args, **report_args,
) )
elif options.action == "lcov": elif options.action == 'lcov':
total = self.coverage.lcov_report( total = self.coverage.lcov_report(
outfile=options.outfile, outfile=options.outfile,
**report_args, **report_args,
@ -752,19 +758,19 @@ class CoverageScript:
# Apply the command line fail-under options, and then use the config # Apply the command line fail-under options, and then use the config
# value, so we can get fail_under from the config file. # value, so we can get fail_under from the config file.
if options.fail_under is not None: if options.fail_under is not None:
self.coverage.set_option("report:fail_under", options.fail_under) self.coverage.set_option('report:fail_under', options.fail_under)
if options.precision is not None: if options.precision is not None:
self.coverage.set_option("report:precision", options.precision) self.coverage.set_option('report:precision', options.precision)
fail_under = cast(float, self.coverage.get_option("report:fail_under")) fail_under = cast(float, self.coverage.get_option('report:fail_under'))
precision = cast(int, self.coverage.get_option("report:precision")) precision = cast(int, self.coverage.get_option('report:precision'))
if should_fail_under(total, fail_under, precision): if should_fail_under(total, fail_under, precision):
msg = "total of {total} is less than fail-under={fail_under:.{p}f}".format( msg = 'total of {total} is less than fail-under={fail_under:.{p}f}'.format(
total=Numbers(precision=precision).display_covered(total), total=Numbers(precision=precision).display_covered(total),
fail_under=fail_under, fail_under=fail_under,
p=precision, p=precision,
) )
print("Coverage failure:", msg) print('Coverage failure:', msg)
return FAIL_UNDER return FAIL_UNDER
return OK return OK
@ -783,12 +789,12 @@ class CoverageScript:
# Handle help. # Handle help.
if options.help: if options.help:
if self.global_option: if self.global_option:
show_help(topic="help") show_help(topic='help')
else: else:
show_help(parser=parser) show_help(parser=parser)
return True return True
if options.action == "help": if options.action == 'help':
if args: if args:
for a in args: for a in args:
parser_maybe = COMMANDS.get(a) parser_maybe = COMMANDS.get(a)
@ -797,12 +803,12 @@ class CoverageScript:
else: else:
show_help(topic=a) show_help(topic=a)
else: else:
show_help(topic="help") show_help(topic='help')
return True return True
# Handle version. # Handle version.
if options.version: if options.version:
show_help(topic="version") show_help(topic='version')
return True return True
return False return False
@ -813,37 +819,37 @@ class CoverageScript:
if not args: if not args:
if options.module: if options.module:
# Specified -m with nothing else. # Specified -m with nothing else.
show_help("No module specified for -m") show_help('No module specified for -m')
return ERR return ERR
command_line = cast(str, self.coverage.get_option("run:command_line")) command_line = cast(str, self.coverage.get_option('run:command_line'))
if command_line is not None: if command_line is not None:
args = shlex.split(command_line) args = shlex.split(command_line)
if args and args[0] in {"-m", "--module"}: if args and args[0] in {'-m', '--module'}:
options.module = True options.module = True
args = args[1:] args = args[1:]
if not args: if not args:
show_help("Nothing to do.") show_help('Nothing to do.')
return ERR return ERR
if options.append and self.coverage.get_option("run:parallel"): if options.append and self.coverage.get_option('run:parallel'):
show_help("Can't append to data files in parallel mode.") show_help("Can't append to data files in parallel mode.")
return ERR return ERR
if options.concurrency == "multiprocessing": if options.concurrency == 'multiprocessing':
# Can't set other run-affecting command line options with # Can't set other run-affecting command line options with
# multiprocessing. # multiprocessing.
for opt_name in ["branch", "include", "omit", "pylib", "source", "timid"]: for opt_name in ['branch', 'include', 'omit', 'pylib', 'source', 'timid']:
# As it happens, all of these options have no default, meaning # As it happens, all of these options have no default, meaning
# they will be None if they have not been specified. # they will be None if they have not been specified.
if getattr(options, opt_name) is not None: if getattr(options, opt_name) is not None:
show_help( show_help(
"Options affecting multiprocessing must only be specified " + 'Options affecting multiprocessing must only be specified ' +
"in a configuration file.\n" + 'in a configuration file.\n' +
f"Remove --{opt_name} from the command line.", f'Remove --{opt_name} from the command line.',
) )
return ERR return ERR
os.environ["COVERAGE_RUN"] = "true" os.environ['COVERAGE_RUN'] = 'true'
runner = PyRunner(args, as_module=bool(options.module)) runner = PyRunner(args, as_module=bool(options.module))
runner.prepare() runner.prepare()
@ -870,28 +876,28 @@ class CoverageScript:
"""Implementation of 'coverage debug'.""" """Implementation of 'coverage debug'."""
if not args: if not args:
show_help("What information would you like: config, data, sys, premain, pybehave?") show_help('What information would you like: config, data, sys, premain, pybehave?')
return ERR return ERR
if args[1:]: if args[1:]:
show_help("Only one topic at a time, please") show_help('Only one topic at a time, please')
return ERR return ERR
if args[0] == "sys": if args[0] == 'sys':
write_formatted_info(print, "sys", self.coverage.sys_info()) write_formatted_info(print, 'sys', self.coverage.sys_info())
elif args[0] == "data": elif args[0] == 'data':
print(info_header("data")) print(info_header('data'))
data_file = self.coverage.config.data_file data_file = self.coverage.config.data_file
debug_data_file(data_file) debug_data_file(data_file)
for filename in combinable_files(data_file): for filename in combinable_files(data_file):
print("-----") print('-----')
debug_data_file(filename) debug_data_file(filename)
elif args[0] == "config": elif args[0] == 'config':
write_formatted_info(print, "config", self.coverage.config.debug_info()) write_formatted_info(print, 'config', self.coverage.config.debug_info())
elif args[0] == "premain": elif args[0] == 'premain':
print(info_header("premain")) print(info_header('premain'))
print(short_stack(full=True)) print(short_stack(full=True))
elif args[0] == "pybehave": elif args[0] == 'pybehave':
write_formatted_info(print, "pybehave", env.debug_info()) write_formatted_info(print, 'pybehave', env.debug_info())
else: else:
show_help(f"Don't know what you mean by {args[0]!r}") show_help(f"Don't know what you mean by {args[0]!r}")
return ERR return ERR
@ -910,7 +916,7 @@ def unshell_list(s: str) -> list[str] | None:
# line, but (not) helpfully, the single quotes are included in the # line, but (not) helpfully, the single quotes are included in the
# argument, so we have to strip them off here. # argument, so we have to strip them off here.
s = s.strip("'") s = s.strip("'")
return s.split(",") return s.split(',')
def unglob_args(args: list[str]) -> list[str]: def unglob_args(args: list[str]) -> list[str]:
@ -918,7 +924,7 @@ def unglob_args(args: list[str]) -> list[str]:
if env.WINDOWS: if env.WINDOWS:
globbed = [] globbed = []
for arg in args: for arg in args:
if "?" in arg or "*" in arg: if '?' in arg or '*' in arg:
globbed.extend(glob.glob(arg)) globbed.extend(glob.glob(arg))
else: else:
globbed.append(arg) globbed.append(arg)
@ -927,7 +933,7 @@ def unglob_args(args: list[str]) -> list[str]:
HELP_TOPICS = { HELP_TOPICS = {
"help": """\ 'help': """\
Coverage.py, version {__version__} {extension_modifier} Coverage.py, version {__version__} {extension_modifier}
Measure, collect, and report on code coverage in Python programs. Measure, collect, and report on code coverage in Python programs.
@ -949,12 +955,12 @@ HELP_TOPICS = {
Use "{program_name} help <command>" for detailed help on any command. Use "{program_name} help <command>" for detailed help on any command.
""", """,
"minimum_help": ( 'minimum_help': (
"Code coverage for Python, version {__version__} {extension_modifier}. " + 'Code coverage for Python, version {__version__} {extension_modifier}. ' +
"Use '{program_name} help' for help." "Use '{program_name} help' for help."
), ),
"version": "Coverage.py, version {__version__} {extension_modifier}", 'version': 'Coverage.py, version {__version__} {extension_modifier}',
} }
@ -987,11 +993,12 @@ def main(argv: list[str] | None = None) -> int | None:
status = None status = None
return status return status
# Profiling using ox_profile. Install it from GitHub: # Profiling using ox_profile. Install it from GitHub:
# pip install git+https://github.com/emin63/ox_profile.git # pip install git+https://github.com/emin63/ox_profile.git
# #
# $set_env.py: COVERAGE_PROFILE - Set to use ox_profile. # $set_env.py: COVERAGE_PROFILE - Set to use ox_profile.
_profile = os.getenv("COVERAGE_PROFILE") _profile = os.getenv('COVERAGE_PROFILE')
if _profile: # pragma: debugging if _profile: # pragma: debugging
from ox_profile.core.launchers import SimpleLauncher # pylint: disable=import-error from ox_profile.core.launchers import SimpleLauncher # pylint: disable=import-error
original_main = main original_main = main
@ -1004,6 +1011,6 @@ if _profile: # pragma: debugging
try: try:
return original_main(argv) return original_main(argv)
finally: finally:
data, _ = profiler.query(re_filter="coverage", max_records=100) data, _ = profiler.query(re_filter='coverage', max_records=100)
print(profiler.show(query=data, limit=100, sep="", col="")) print(profiler.show(query=data, limit=100, sep='', col=''))
profiler.cancel() profiler.cancel()

View file

@ -1,18 +1,20 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Raw data collector for coverage.py.""" """Raw data collector for coverage.py."""
from __future__ import annotations from __future__ import annotations
import functools import functools
import os import os
import sys import sys
from types import FrameType from types import FrameType
from typing import ( from typing import Any
cast, Any, Callable, Dict, List, Mapping, Set, TypeVar, from typing import Callable
) from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import Set
from typing import TypeVar
from coverage import env from coverage import env
from coverage.config import CoverageConfig from coverage.config import CoverageConfig
@ -20,13 +22,17 @@ from coverage.data import CoverageData
from coverage.debug import short_stack from coverage.debug import short_stack
from coverage.disposition import FileDisposition from coverage.disposition import FileDisposition
from coverage.exceptions import ConfigError from coverage.exceptions import ConfigError
from coverage.misc import human_sorted_items, isolate_module from coverage.misc import human_sorted_items
from coverage.misc import isolate_module
from coverage.plugin import CoveragePlugin from coverage.plugin import CoveragePlugin
from coverage.pytracer import PyTracer from coverage.pytracer import PyTracer
from coverage.sysmon import SysMonitor from coverage.sysmon import SysMonitor
from coverage.types import ( from coverage.types import TArc
TArc, TFileDisposition, TTraceData, TTraceFn, TracerCore, TWarnFn, from coverage.types import TFileDisposition
) from coverage.types import TracerCore
from coverage.types import TTraceData
from coverage.types import TTraceFn
from coverage.types import TWarnFn
os = isolate_module(os) os = isolate_module(os)
@ -37,7 +43,7 @@ try:
HAS_CTRACER = True HAS_CTRACER = True
except ImportError: except ImportError:
# Couldn't import the C extension, maybe it isn't built. # Couldn't import the C extension, maybe it isn't built.
if os.getenv("COVERAGE_CORE") == "ctrace": # pragma: part covered if os.getenv('COVERAGE_CORE') == 'ctrace': # pragma: part covered
# During testing, we use the COVERAGE_CORE environment variable # During testing, we use the COVERAGE_CORE environment variable
# to indicate that we've fiddled with the environment to test this # to indicate that we've fiddled with the environment to test this
# fallback code. If we thought we had a C tracer, but couldn't import # fallback code. If we thought we had a C tracer, but couldn't import
@ -48,7 +54,7 @@ except ImportError:
sys.exit(1) sys.exit(1)
HAS_CTRACER = False HAS_CTRACER = False
T = TypeVar("T") T = TypeVar('T')
class Collector: class Collector:
@ -73,7 +79,7 @@ class Collector:
_collectors: list[Collector] = [] _collectors: list[Collector] = []
# The concurrency settings we support here. # The concurrency settings we support here.
LIGHT_THREADS = {"greenlet", "eventlet", "gevent"} LIGHT_THREADS = {'greenlet', 'eventlet', 'gevent'}
def __init__( def __init__(
self, self,
@ -130,7 +136,7 @@ class Collector:
self.branch = branch self.branch = branch
self.warn = warn self.warn = warn
self.concurrency = concurrency self.concurrency = concurrency
assert isinstance(self.concurrency, list), f"Expected a list: {self.concurrency!r}" assert isinstance(self.concurrency, list), f'Expected a list: {self.concurrency!r}'
self.pid = os.getpid() self.pid = os.getpid()
@ -147,12 +153,12 @@ class Collector:
core: str | None core: str | None
if timid: if timid:
core = "pytrace" core = 'pytrace'
else: else:
core = os.getenv("COVERAGE_CORE") core = os.getenv('COVERAGE_CORE')
if core == "sysmon" and not env.PYBEHAVIOR.pep669: if core == 'sysmon' and not env.PYBEHAVIOR.pep669:
self.warn("sys.monitoring isn't available, using default core", slug="no-sysmon") self.warn("sys.monitoring isn't available, using default core", slug='no-sysmon')
core = None core = None
if not core: if not core:
@ -160,25 +166,25 @@ class Collector:
# if env.PYBEHAVIOR.pep669 and self.should_start_context is None: # if env.PYBEHAVIOR.pep669 and self.should_start_context is None:
# core = "sysmon" # core = "sysmon"
if HAS_CTRACER: if HAS_CTRACER:
core = "ctrace" core = 'ctrace'
else: else:
core = "pytrace" core = 'pytrace'
if core == "sysmon": if core == 'sysmon':
self._trace_class = SysMonitor self._trace_class = SysMonitor
self._core_kwargs = {"tool_id": 3 if metacov else 1} self._core_kwargs = {'tool_id': 3 if metacov else 1}
self.file_disposition_class = FileDisposition self.file_disposition_class = FileDisposition
self.supports_plugins = False self.supports_plugins = False
self.packed_arcs = False self.packed_arcs = False
self.systrace = False self.systrace = False
elif core == "ctrace": elif core == 'ctrace':
self._trace_class = CTracer self._trace_class = CTracer
self._core_kwargs = {} self._core_kwargs = {}
self.file_disposition_class = CFileDisposition self.file_disposition_class = CFileDisposition
self.supports_plugins = True self.supports_plugins = True
self.packed_arcs = True self.packed_arcs = True
self.systrace = True self.systrace = True
elif core == "pytrace": elif core == 'pytrace':
self._trace_class = PyTracer self._trace_class = PyTracer
self._core_kwargs = {} self._core_kwargs = {}
self.file_disposition_class = FileDisposition self.file_disposition_class = FileDisposition
@ -186,42 +192,42 @@ class Collector:
self.packed_arcs = False self.packed_arcs = False
self.systrace = True self.systrace = True
else: else:
raise ConfigError(f"Unknown core value: {core!r}") raise ConfigError(f'Unknown core value: {core!r}')
# We can handle a few concurrency options here, but only one at a time. # We can handle a few concurrency options here, but only one at a time.
concurrencies = set(self.concurrency) concurrencies = set(self.concurrency)
unknown = concurrencies - CoverageConfig.CONCURRENCY_CHOICES unknown = concurrencies - CoverageConfig.CONCURRENCY_CHOICES
if unknown: if unknown:
show = ", ".join(sorted(unknown)) show = ', '.join(sorted(unknown))
raise ConfigError(f"Unknown concurrency choices: {show}") raise ConfigError(f'Unknown concurrency choices: {show}')
light_threads = concurrencies & self.LIGHT_THREADS light_threads = concurrencies & self.LIGHT_THREADS
if len(light_threads) > 1: if len(light_threads) > 1:
show = ", ".join(sorted(light_threads)) show = ', '.join(sorted(light_threads))
raise ConfigError(f"Conflicting concurrency settings: {show}") raise ConfigError(f'Conflicting concurrency settings: {show}')
do_threading = False do_threading = False
tried = "nothing" # to satisfy pylint tried = 'nothing' # to satisfy pylint
try: try:
if "greenlet" in concurrencies: if 'greenlet' in concurrencies:
tried = "greenlet" tried = 'greenlet'
import greenlet import greenlet
self.concur_id_func = greenlet.getcurrent self.concur_id_func = greenlet.getcurrent
elif "eventlet" in concurrencies: elif 'eventlet' in concurrencies:
tried = "eventlet" tried = 'eventlet'
import eventlet.greenthread # pylint: disable=import-error,useless-suppression import eventlet.greenthread # pylint: disable=import-error,useless-suppression
self.concur_id_func = eventlet.greenthread.getcurrent self.concur_id_func = eventlet.greenthread.getcurrent
elif "gevent" in concurrencies: elif 'gevent' in concurrencies:
tried = "gevent" tried = 'gevent'
import gevent # pylint: disable=import-error,useless-suppression import gevent # pylint: disable=import-error,useless-suppression
self.concur_id_func = gevent.getcurrent self.concur_id_func = gevent.getcurrent
if "thread" in concurrencies: if 'thread' in concurrencies:
do_threading = True do_threading = True
except ImportError as ex: except ImportError as ex:
msg = f"Couldn't trace with concurrency={tried}, the module isn't installed." msg = f"Couldn't trace with concurrency={tried}, the module isn't installed."
raise ConfigError(msg) from ex raise ConfigError(msg) from ex
if self.concur_id_func and not hasattr(self._trace_class, "concur_id_func"): if self.concur_id_func and not hasattr(self._trace_class, 'concur_id_func'):
raise ConfigError( raise ConfigError(
"Can't support concurrency={} with {}, only threads are supported.".format( "Can't support concurrency={} with {}, only threads are supported.".format(
tried, self.tracer_name(), tried, self.tracer_name(),
@ -238,7 +244,7 @@ class Collector:
self.reset() self.reset()
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Collector at {id(self):#x}: {self.tracer_name()}>" return f'<Collector at {id(self):#x}: {self.tracer_name()}>'
def use_data(self, covdata: CoverageData, context: str | None) -> None: def use_data(self, covdata: CoverageData, context: str | None) -> None:
"""Use `covdata` for recording data.""" """Use `covdata` for recording data."""
@ -296,7 +302,7 @@ class Collector:
# #
# This gives a 20% benefit on the workload described at # This gives a 20% benefit on the workload described at
# https://bitbucket.org/pypy/pypy/issue/1871/10x-slower-than-cpython-under-coverage # https://bitbucket.org/pypy/pypy/issue/1871/10x-slower-than-cpython-under-coverage
self.should_trace_cache = __pypy__.newdict("module") self.should_trace_cache = __pypy__.newdict('module')
else: else:
self.should_trace_cache = {} self.should_trace_cache = {}
@ -394,9 +400,9 @@ class Collector:
"""Stop collecting trace information.""" """Stop collecting trace information."""
assert self._collectors assert self._collectors
if self._collectors[-1] is not self: if self._collectors[-1] is not self:
print("self._collectors:") print('self._collectors:')
for c in self._collectors: for c in self._collectors:
print(f" {c!r}\n{c.origin}") print(f' {c!r}\n{c.origin}')
assert self._collectors[-1] is self, ( assert self._collectors[-1] is self, (
f"Expected current collector to be {self!r}, but it's {self._collectors[-1]!r}" f"Expected current collector to be {self!r}, but it's {self._collectors[-1]!r}"
) )
@ -414,9 +420,9 @@ class Collector:
tracer.stop() tracer.stop()
stats = tracer.get_stats() stats = tracer.get_stats()
if stats: if stats:
print("\nCoverage.py tracer stats:") print('\nCoverage.py tracer stats:')
for k, v in human_sorted_items(stats.items()): for k, v in human_sorted_items(stats.items()):
print(f"{k:>20}: {v}") print(f'{k:>20}: {v}')
if self.threading: if self.threading:
self.threading.settrace(None) self.threading.settrace(None)
@ -433,7 +439,7 @@ class Collector:
def post_fork(self) -> None: def post_fork(self) -> None:
"""After a fork, tracers might need to adjust.""" """After a fork, tracers might need to adjust."""
for tracer in self.tracers: for tracer in self.tracers:
if hasattr(tracer, "post_fork"): if hasattr(tracer, 'post_fork'):
tracer.post_fork() tracer.post_fork()
def _activity(self) -> bool: def _activity(self) -> bool:
@ -451,7 +457,7 @@ class Collector:
if self.static_context: if self.static_context:
context = self.static_context context = self.static_context
if new_context: if new_context:
context += "|" + new_context context += '|' + new_context
else: else:
context = new_context context = new_context
self.covdata.set_context(context) self.covdata.set_context(context)
@ -462,7 +468,7 @@ class Collector:
assert file_tracer is not None assert file_tracer is not None
plugin = file_tracer._coverage_plugin plugin = file_tracer._coverage_plugin
plugin_name = plugin._coverage_plugin_name plugin_name = plugin._coverage_plugin_name
self.warn(f"Disabling plug-in {plugin_name!r} due to previous exception") self.warn(f'Disabling plug-in {plugin_name!r} due to previous exception')
plugin._coverage_enabled = False plugin._coverage_enabled = False
disposition.trace = False disposition.trace = False

View file

@ -1,28 +1,30 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Config file for coverage.py""" """Config file for coverage.py"""
from __future__ import annotations from __future__ import annotations
import collections import collections
import configparser import configparser
import copy import copy
import os
import os.path import os.path
import re import re
from typing import Any
from typing import ( from typing import Callable
Any, Callable, Iterable, Union, from typing import Iterable
) from typing import Union
from coverage.exceptions import ConfigError from coverage.exceptions import ConfigError
from coverage.misc import isolate_module, human_sorted_items, substitute_variables from coverage.misc import human_sorted_items
from coverage.tomlconfig import TomlConfigParser, TomlDecodeError from coverage.misc import isolate_module
from coverage.types import ( from coverage.misc import substitute_variables
TConfigurable, TConfigSectionIn, TConfigValueIn, TConfigSectionOut, from coverage.tomlconfig import TomlConfigParser
TConfigValueOut, TPluginConfig, from coverage.tomlconfig import TomlDecodeError
) from coverage.types import TConfigSectionIn
from coverage.types import TConfigSectionOut
from coverage.types import TConfigurable
from coverage.types import TConfigValueIn
from coverage.types import TConfigValueOut
from coverage.types import TPluginConfig
os = isolate_module(os) os = isolate_module(os)
@ -39,17 +41,17 @@ class HandyConfigParser(configparser.ConfigParser):
""" """
super().__init__(interpolation=None) super().__init__(interpolation=None)
self.section_prefixes = ["coverage:"] self.section_prefixes = ['coverage:']
if our_file: if our_file:
self.section_prefixes.append("") self.section_prefixes.append('')
def read( # type: ignore[override] def read( # type: ignore[override]
self, self,
filenames: Iterable[str], filenames: Iterable[str],
encoding_unused: str | None = None, encoding_unused: str | None = None,
) -> list[str]: ) -> list[str]:
"""Read a file name as UTF-8 configuration data.""" """Read a file name as UTF-8 configuration data."""
return super().read(filenames, encoding="utf-8") return super().read(filenames, encoding='utf-8')
def real_section(self, section: str) -> str | None: def real_section(self, section: str) -> str | None:
"""Get the actual name of a section.""" """Get the actual name of a section."""
@ -73,7 +75,7 @@ class HandyConfigParser(configparser.ConfigParser):
real_section = self.real_section(section) real_section = self.real_section(section)
if real_section is not None: if real_section is not None:
return super().options(real_section) return super().options(real_section)
raise ConfigError(f"No section: {section!r}") raise ConfigError(f'No section: {section!r}')
def get_section(self, section: str) -> TConfigSectionOut: def get_section(self, section: str) -> TConfigSectionOut:
"""Get the contents of a section, as a dictionary.""" """Get the contents of a section, as a dictionary."""
@ -82,7 +84,7 @@ class HandyConfigParser(configparser.ConfigParser):
d[opt] = self.get(section, opt) d[opt] = self.get(section, opt)
return d return d
def get(self, section: str, option: str, *args: Any, **kwargs: Any) -> str: # type: ignore def get(self, section: str, option: str, *args: Any, **kwargs: Any) -> str: # type: ignore
"""Get a value, replacing environment variables also. """Get a value, replacing environment variables also.
The arguments are the same as `ConfigParser.get`, but in the found The arguments are the same as `ConfigParser.get`, but in the found
@ -97,7 +99,7 @@ class HandyConfigParser(configparser.ConfigParser):
if super().has_option(real_section, option): if super().has_option(real_section, option):
break break
else: else:
raise ConfigError(f"No option {option!r} in section: {section!r}") raise ConfigError(f'No option {option!r} in section: {section!r}')
v: str = super().get(real_section, option, *args, **kwargs) v: str = super().get(real_section, option, *args, **kwargs)
v = substitute_variables(v, os.environ) v = substitute_variables(v, os.environ)
@ -114,8 +116,8 @@ class HandyConfigParser(configparser.ConfigParser):
""" """
value_list = self.get(section, option) value_list = self.get(section, option)
values = [] values = []
for value_line in value_list.split("\n"): for value_line in value_list.split('\n'):
for value in value_line.split(","): for value in value_line.split(','):
value = value.strip() value = value.strip()
if value: if value:
values.append(value) values.append(value)
@ -138,7 +140,7 @@ class HandyConfigParser(configparser.ConfigParser):
re.compile(value) re.compile(value)
except re.error as e: except re.error as e:
raise ConfigError( raise ConfigError(
f"Invalid [{section}].{option} value {value!r}: {e}", f'Invalid [{section}].{option} value {value!r}: {e}',
) from e ) from e
if value: if value:
value_list.append(value) value_list.append(value)
@ -150,20 +152,20 @@ TConfigParser = Union[HandyConfigParser, TomlConfigParser]
# The default line exclusion regexes. # The default line exclusion regexes.
DEFAULT_EXCLUDE = [ DEFAULT_EXCLUDE = [
r"#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER)", r'#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER)',
] ]
# The default partial branch regexes, to be modified by the user. # The default partial branch regexes, to be modified by the user.
DEFAULT_PARTIAL = [ DEFAULT_PARTIAL = [
r"#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(branch|BRANCH)", r'#\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(branch|BRANCH)',
] ]
# The default partial branch regexes, based on Python semantics. # The default partial branch regexes, based on Python semantics.
# These are any Python branching constructs that can't actually execute all # These are any Python branching constructs that can't actually execute all
# their branches. # their branches.
DEFAULT_PARTIAL_ALWAYS = [ DEFAULT_PARTIAL_ALWAYS = [
"while (True|1|False|0):", 'while (True|1|False|0):',
"if (True|1|False|0):", 'if (True|1|False|0):',
] ]
@ -197,7 +199,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
self.concurrency: list[str] = [] self.concurrency: list[str] = []
self.context: str | None = None self.context: str | None = None
self.cover_pylib = False self.cover_pylib = False
self.data_file = ".coverage" self.data_file = '.coverage'
self.debug: list[str] = [] self.debug: list[str] = []
self.debug_file: str | None = None self.debug_file: str | None = None
self.disable_warnings: list[str] = [] self.disable_warnings: list[str] = []
@ -233,23 +235,23 @@ class CoverageConfig(TConfigurable, TPluginConfig):
# Defaults for [html] # Defaults for [html]
self.extra_css: str | None = None self.extra_css: str | None = None
self.html_dir = "htmlcov" self.html_dir = 'htmlcov'
self.html_skip_covered: bool | None = None self.html_skip_covered: bool | None = None
self.html_skip_empty: bool | None = None self.html_skip_empty: bool | None = None
self.html_title = "Coverage report" self.html_title = 'Coverage report'
self.show_contexts = False self.show_contexts = False
# Defaults for [xml] # Defaults for [xml]
self.xml_output = "coverage.xml" self.xml_output = 'coverage.xml'
self.xml_package_depth = 99 self.xml_package_depth = 99
# Defaults for [json] # Defaults for [json]
self.json_output = "coverage.json" self.json_output = 'coverage.json'
self.json_pretty_print = False self.json_pretty_print = False
self.json_show_contexts = False self.json_show_contexts = False
# Defaults for [lcov] # Defaults for [lcov]
self.lcov_output = "coverage.lcov" self.lcov_output = 'coverage.lcov'
# Defaults for [paths] # Defaults for [paths]
self.paths: dict[str, list[str]] = {} self.paths: dict[str, list[str]] = {}
@ -258,9 +260,9 @@ class CoverageConfig(TConfigurable, TPluginConfig):
self.plugin_options: dict[str, TConfigSectionOut] = {} self.plugin_options: dict[str, TConfigSectionOut] = {}
MUST_BE_LIST = { MUST_BE_LIST = {
"debug", "concurrency", "plugins", 'debug', 'concurrency', 'plugins',
"report_omit", "report_include", 'report_omit', 'report_include',
"run_omit", "run_include", 'run_omit', 'run_include',
} }
def from_args(self, **kwargs: TConfigValueIn) -> None: def from_args(self, **kwargs: TConfigValueIn) -> None:
@ -286,7 +288,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
""" """
_, ext = os.path.splitext(filename) _, ext = os.path.splitext(filename)
cp: TConfigParser cp: TConfigParser
if ext == ".toml": if ext == '.toml':
cp = TomlConfigParser(our_file) cp = TomlConfigParser(our_file)
else: else:
cp = HandyConfigParser(our_file) cp = HandyConfigParser(our_file)
@ -314,7 +316,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
# Check that there are no unrecognized options. # Check that there are no unrecognized options.
all_options = collections.defaultdict(set) all_options = collections.defaultdict(set)
for option_spec in self.CONFIG_FILE_OPTIONS: for option_spec in self.CONFIG_FILE_OPTIONS:
section, option = option_spec[1].split(":") section, option = option_spec[1].split(':')
all_options[section].add(option) all_options[section].add(option)
for section, options in all_options.items(): for section, options in all_options.items():
@ -328,9 +330,9 @@ class CoverageConfig(TConfigurable, TPluginConfig):
) )
# [paths] is special # [paths] is special
if cp.has_section("paths"): if cp.has_section('paths'):
for option in cp.options("paths"): for option in cp.options('paths'):
self.paths[option] = cp.getlist("paths", option) self.paths[option] = cp.getlist('paths', option)
any_set = True any_set = True
# plugins can have options # plugins can have options
@ -349,7 +351,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
if used: if used:
self.config_file = os.path.abspath(filename) self.config_file = os.path.abspath(filename)
with open(filename, "rb") as f: with open(filename, 'rb') as f:
self._config_contents = f.read() self._config_contents = f.read()
return used return used
@ -358,7 +360,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
"""Return a copy of the configuration.""" """Return a copy of the configuration."""
return copy.deepcopy(self) return copy.deepcopy(self)
CONCURRENCY_CHOICES = {"thread", "gevent", "greenlet", "eventlet", "multiprocessing"} CONCURRENCY_CHOICES = {'thread', 'gevent', 'greenlet', 'eventlet', 'multiprocessing'}
CONFIG_FILE_OPTIONS = [ CONFIG_FILE_OPTIONS = [
# These are *args for _set_attr_from_config_option: # These are *args for _set_attr_from_config_option:
@ -370,64 +372,64 @@ class CoverageConfig(TConfigurable, TPluginConfig):
# configuration value from the file. # configuration value from the file.
# [run] # [run]
("branch", "run:branch", "boolean"), ('branch', 'run:branch', 'boolean'),
("command_line", "run:command_line"), ('command_line', 'run:command_line'),
("concurrency", "run:concurrency", "list"), ('concurrency', 'run:concurrency', 'list'),
("context", "run:context"), ('context', 'run:context'),
("cover_pylib", "run:cover_pylib", "boolean"), ('cover_pylib', 'run:cover_pylib', 'boolean'),
("data_file", "run:data_file"), ('data_file', 'run:data_file'),
("debug", "run:debug", "list"), ('debug', 'run:debug', 'list'),
("debug_file", "run:debug_file"), ('debug_file', 'run:debug_file'),
("disable_warnings", "run:disable_warnings", "list"), ('disable_warnings', 'run:disable_warnings', 'list'),
("dynamic_context", "run:dynamic_context"), ('dynamic_context', 'run:dynamic_context'),
("parallel", "run:parallel", "boolean"), ('parallel', 'run:parallel', 'boolean'),
("plugins", "run:plugins", "list"), ('plugins', 'run:plugins', 'list'),
("relative_files", "run:relative_files", "boolean"), ('relative_files', 'run:relative_files', 'boolean'),
("run_include", "run:include", "list"), ('run_include', 'run:include', 'list'),
("run_omit", "run:omit", "list"), ('run_omit', 'run:omit', 'list'),
("sigterm", "run:sigterm", "boolean"), ('sigterm', 'run:sigterm', 'boolean'),
("source", "run:source", "list"), ('source', 'run:source', 'list'),
("source_pkgs", "run:source_pkgs", "list"), ('source_pkgs', 'run:source_pkgs', 'list'),
("timid", "run:timid", "boolean"), ('timid', 'run:timid', 'boolean'),
("_crash", "run:_crash"), ('_crash', 'run:_crash'),
# [report] # [report]
("exclude_list", "report:exclude_lines", "regexlist"), ('exclude_list', 'report:exclude_lines', 'regexlist'),
("exclude_also", "report:exclude_also", "regexlist"), ('exclude_also', 'report:exclude_also', 'regexlist'),
("fail_under", "report:fail_under", "float"), ('fail_under', 'report:fail_under', 'float'),
("format", "report:format"), ('format', 'report:format'),
("ignore_errors", "report:ignore_errors", "boolean"), ('ignore_errors', 'report:ignore_errors', 'boolean'),
("include_namespace_packages", "report:include_namespace_packages", "boolean"), ('include_namespace_packages', 'report:include_namespace_packages', 'boolean'),
("partial_always_list", "report:partial_branches_always", "regexlist"), ('partial_always_list', 'report:partial_branches_always', 'regexlist'),
("partial_list", "report:partial_branches", "regexlist"), ('partial_list', 'report:partial_branches', 'regexlist'),
("precision", "report:precision", "int"), ('precision', 'report:precision', 'int'),
("report_contexts", "report:contexts", "list"), ('report_contexts', 'report:contexts', 'list'),
("report_include", "report:include", "list"), ('report_include', 'report:include', 'list'),
("report_omit", "report:omit", "list"), ('report_omit', 'report:omit', 'list'),
("show_missing", "report:show_missing", "boolean"), ('show_missing', 'report:show_missing', 'boolean'),
("skip_covered", "report:skip_covered", "boolean"), ('skip_covered', 'report:skip_covered', 'boolean'),
("skip_empty", "report:skip_empty", "boolean"), ('skip_empty', 'report:skip_empty', 'boolean'),
("sort", "report:sort"), ('sort', 'report:sort'),
# [html] # [html]
("extra_css", "html:extra_css"), ('extra_css', 'html:extra_css'),
("html_dir", "html:directory"), ('html_dir', 'html:directory'),
("html_skip_covered", "html:skip_covered", "boolean"), ('html_skip_covered', 'html:skip_covered', 'boolean'),
("html_skip_empty", "html:skip_empty", "boolean"), ('html_skip_empty', 'html:skip_empty', 'boolean'),
("html_title", "html:title"), ('html_title', 'html:title'),
("show_contexts", "html:show_contexts", "boolean"), ('show_contexts', 'html:show_contexts', 'boolean'),
# [xml] # [xml]
("xml_output", "xml:output"), ('xml_output', 'xml:output'),
("xml_package_depth", "xml:package_depth", "int"), ('xml_package_depth', 'xml:package_depth', 'int'),
# [json] # [json]
("json_output", "json:output"), ('json_output', 'json:output'),
("json_pretty_print", "json:pretty_print", "boolean"), ('json_pretty_print', 'json:pretty_print', 'boolean'),
("json_show_contexts", "json:show_contexts", "boolean"), ('json_show_contexts', 'json:show_contexts', 'boolean'),
# [lcov] # [lcov]
("lcov_output", "lcov:output"), ('lcov_output', 'lcov:output'),
] ]
def _set_attr_from_config_option( def _set_attr_from_config_option(
@ -435,16 +437,16 @@ class CoverageConfig(TConfigurable, TPluginConfig):
cp: TConfigParser, cp: TConfigParser,
attr: str, attr: str,
where: str, where: str,
type_: str = "", type_: str = '',
) -> bool: ) -> bool:
"""Set an attribute on self if it exists in the ConfigParser. """Set an attribute on self if it exists in the ConfigParser.
Returns True if the attribute was set. Returns True if the attribute was set.
""" """
section, option = where.split(":") section, option = where.split(':')
if cp.has_option(section, option): if cp.has_option(section, option):
method = getattr(cp, "get" + type_) method = getattr(cp, 'get' + type_)
setattr(self, attr, method(section, option)) setattr(self, attr, method(section, option))
return True return True
return False return False
@ -464,7 +466,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
""" """
# Special-cased options. # Special-cased options.
if option_name == "paths": if option_name == 'paths':
self.paths = value # type: ignore[assignment] self.paths = value # type: ignore[assignment]
return return
@ -476,13 +478,13 @@ class CoverageConfig(TConfigurable, TPluginConfig):
return return
# See if it's a plugin option. # See if it's a plugin option.
plugin_name, _, key = option_name.partition(":") plugin_name, _, key = option_name.partition(':')
if key and plugin_name in self.plugins: if key and plugin_name in self.plugins:
self.plugin_options.setdefault(plugin_name, {})[key] = value # type: ignore[index] self.plugin_options.setdefault(plugin_name, {})[key] = value # type: ignore[index]
return return
# If we get here, we didn't find the option. # If we get here, we didn't find the option.
raise ConfigError(f"No such option: {option_name!r}") raise ConfigError(f'No such option: {option_name!r}')
def get_option(self, option_name: str) -> TConfigValueOut | None: def get_option(self, option_name: str) -> TConfigValueOut | None:
"""Get an option from the configuration. """Get an option from the configuration.
@ -495,7 +497,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
""" """
# Special-cased options. # Special-cased options.
if option_name == "paths": if option_name == 'paths':
return self.paths # type: ignore[return-value] return self.paths # type: ignore[return-value]
# Check all the hard-coded options. # Check all the hard-coded options.
@ -505,12 +507,12 @@ class CoverageConfig(TConfigurable, TPluginConfig):
return getattr(self, attr) # type: ignore[no-any-return] return getattr(self, attr) # type: ignore[no-any-return]
# See if it's a plugin option. # See if it's a plugin option.
plugin_name, _, key = option_name.partition(":") plugin_name, _, key = option_name.partition(':')
if key and plugin_name in self.plugins: if key and plugin_name in self.plugins:
return self.plugin_options.get(plugin_name, {}).get(key) return self.plugin_options.get(plugin_name, {}).get(key)
# If we get here, we didn't find the option. # If we get here, we didn't find the option.
raise ConfigError(f"No such option: {option_name!r}") raise ConfigError(f'No such option: {option_name!r}')
def post_process_file(self, path: str) -> str: def post_process_file(self, path: str) -> str:
"""Make final adjustments to a file path to make it usable.""" """Make final adjustments to a file path to make it usable."""
@ -530,7 +532,7 @@ class CoverageConfig(TConfigurable, TPluginConfig):
def debug_info(self) -> list[tuple[str, Any]]: def debug_info(self) -> list[tuple[str, Any]]:
"""Make a list of (name, value) pairs for writing debug info.""" """Make a list of (name, value) pairs for writing debug info."""
return human_sorted_items( return human_sorted_items(
(k, v) for k, v in self.__dict__.items() if not k.startswith("_") (k, v) for k, v in self.__dict__.items() if not k.startswith('_')
) )
@ -543,24 +545,24 @@ def config_files_to_try(config_file: bool | str) -> list[tuple[str, bool, bool]]
# Some API users were specifying ".coveragerc" to mean the same as # Some API users were specifying ".coveragerc" to mean the same as
# True, so make it so. # True, so make it so.
if config_file == ".coveragerc": if config_file == '.coveragerc':
config_file = True config_file = True
specified_file = (config_file is not True) specified_file = (config_file is not True)
if not specified_file: if not specified_file:
# No file was specified. Check COVERAGE_RCFILE. # No file was specified. Check COVERAGE_RCFILE.
rcfile = os.getenv("COVERAGE_RCFILE") rcfile = os.getenv('COVERAGE_RCFILE')
if rcfile: if rcfile:
config_file = rcfile config_file = rcfile
specified_file = True specified_file = True
if not specified_file: if not specified_file:
# Still no file specified. Default to .coveragerc # Still no file specified. Default to .coveragerc
config_file = ".coveragerc" config_file = '.coveragerc'
assert isinstance(config_file, str) assert isinstance(config_file, str)
files_to_try = [ files_to_try = [
(config_file, True, specified_file), (config_file, True, specified_file),
("setup.cfg", False, False), ('setup.cfg', False, False),
("tox.ini", False, False), ('tox.ini', False, False),
("pyproject.toml", False, False), ('pyproject.toml', False, False),
] ]
return files_to_try return files_to_try
@ -601,13 +603,13 @@ def read_coverage_config(
raise ConfigError(f"Couldn't read {fname!r} as a config file") raise ConfigError(f"Couldn't read {fname!r} as a config file")
# 3) from environment variables: # 3) from environment variables:
env_data_file = os.getenv("COVERAGE_FILE") env_data_file = os.getenv('COVERAGE_FILE')
if env_data_file: if env_data_file:
config.data_file = env_data_file config.data_file = env_data_file
# $set_env.py: COVERAGE_DEBUG - Debug options: https://coverage.rtfd.io/cmd.html#debug # $set_env.py: COVERAGE_DEBUG - Debug options: https://coverage.rtfd.io/cmd.html#debug
debugs = os.getenv("COVERAGE_DEBUG") debugs = os.getenv('COVERAGE_DEBUG')
if debugs: if debugs:
config.debug.extend(d.strip() for d in debugs.split(",")) config.debug.extend(d.strip() for d in debugs.split(','))
# 4) from constructor arguments: # 4) from constructor arguments:
config.from_args(**kwargs) config.from_args(**kwargs)

View file

@ -1,12 +1,12 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Determine contexts for coverage.py""" """Determine contexts for coverage.py"""
from __future__ import annotations from __future__ import annotations
from types import FrameType from types import FrameType
from typing import cast, Callable, Sequence from typing import Callable
from typing import cast
from typing import Sequence
def combine_context_switchers( def combine_context_switchers(
@ -44,7 +44,7 @@ def combine_context_switchers(
def should_start_context_test_function(frame: FrameType) -> str | None: def should_start_context_test_function(frame: FrameType) -> str | None:
"""Is this frame calling a test_* function?""" """Is this frame calling a test_* function?"""
co_name = frame.f_code.co_name co_name = frame.f_code.co_name
if co_name.startswith("test") or co_name == "runTest": if co_name.startswith('test') or co_name == 'runTest':
return qualname_from_frame(frame) return qualname_from_frame(frame)
return None return None
@ -54,19 +54,19 @@ def qualname_from_frame(frame: FrameType) -> str | None:
co = frame.f_code co = frame.f_code
fname = co.co_name fname = co.co_name
method = None method = None
if co.co_argcount and co.co_varnames[0] == "self": if co.co_argcount and co.co_varnames[0] == 'self':
self = frame.f_locals.get("self", None) self = frame.f_locals.get('self', None)
method = getattr(self, fname, None) method = getattr(self, fname, None)
if method is None: if method is None:
func = frame.f_globals.get(fname) func = frame.f_globals.get(fname)
if func is None: if func is None:
return None return None
return cast(str, func.__module__ + "." + fname) return cast(str, func.__module__ + '.' + fname)
func = getattr(method, "__func__", None) func = getattr(method, '__func__', None)
if func is None: if func is None:
cls = self.__class__ cls = self.__class__
return cast(str, cls.__module__ + "." + cls.__name__ + "." + fname) return cast(str, cls.__module__ + '.' + cls.__name__ + '.' + fname)
return cast(str, func.__module__ + "." + func.__qualname__) return cast(str, func.__module__ + '.' + func.__qualname__)

View file

@ -1,14 +1,11 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Central control stuff for coverage.py.""" """Central control stuff for coverage.py."""
from __future__ import annotations from __future__ import annotations
import atexit import atexit
import collections import collections
import contextlib import contextlib
import os
import os.path import os.path
import platform import platform
import signal import signal
@ -16,31 +13,48 @@ import sys
import threading import threading
import time import time
import warnings import warnings
from types import FrameType from types import FrameType
from typing import ( from typing import Any
cast, from typing import Callable
Any, Callable, IO, Iterable, Iterator, List, from typing import cast
) from typing import IO
from typing import Iterable
from typing import Iterator
from typing import List
from coverage import env from coverage import env
from coverage.annotate import AnnotateReporter from coverage.annotate import AnnotateReporter
from coverage.collector import Collector, HAS_CTRACER from coverage.collector import Collector
from coverage.config import CoverageConfig, read_coverage_config from coverage.collector import HAS_CTRACER
from coverage.context import should_start_context_test_function, combine_context_switchers from coverage.config import CoverageConfig
from coverage.data import CoverageData, combine_parallel_data from coverage.config import read_coverage_config
from coverage.debug import ( from coverage.context import combine_context_switchers
DebugControl, NoDebugging, short_stack, write_formatted_info, relevant_environment_display, from coverage.context import should_start_context_test_function
) from coverage.data import combine_parallel_data
from coverage.data import CoverageData
from coverage.debug import DebugControl
from coverage.debug import NoDebugging
from coverage.debug import relevant_environment_display
from coverage.debug import short_stack
from coverage.debug import write_formatted_info
from coverage.disposition import disposition_debug_msg from coverage.disposition import disposition_debug_msg
from coverage.exceptions import ConfigError, CoverageException, CoverageWarning, PluginError from coverage.exceptions import ConfigError
from coverage.files import PathAliases, abs_file, relative_filename, set_relative_directory from coverage.exceptions import CoverageException
from coverage.exceptions import CoverageWarning
from coverage.exceptions import PluginError
from coverage.files import abs_file
from coverage.files import PathAliases
from coverage.files import relative_filename
from coverage.files import set_relative_directory
from coverage.html import HtmlReporter from coverage.html import HtmlReporter
from coverage.inorout import InOrOut from coverage.inorout import InOrOut
from coverage.jsonreport import JsonReporter from coverage.jsonreport import JsonReporter
from coverage.lcovreport import LcovReporter from coverage.lcovreport import LcovReporter
from coverage.misc import bool_or_none, join_regex from coverage.misc import bool_or_none
from coverage.misc import DefaultValue, ensure_dir_for_file, isolate_module from coverage.misc import DefaultValue
from coverage.misc import ensure_dir_for_file
from coverage.misc import isolate_module
from coverage.misc import join_regex
from coverage.multiproc import patch_multiprocessing from coverage.multiproc import patch_multiprocessing
from coverage.plugin import FileReporter from coverage.plugin import FileReporter
from coverage.plugin_support import Plugins from coverage.plugin_support import Plugins
@ -48,14 +62,19 @@ from coverage.python import PythonFileReporter
from coverage.report import SummaryReporter from coverage.report import SummaryReporter
from coverage.report_core import render_report from coverage.report_core import render_report
from coverage.results import Analysis from coverage.results import Analysis
from coverage.types import ( from coverage.types import FilePath
FilePath, TConfigurable, TConfigSectionIn, TConfigValueIn, TConfigValueOut, from coverage.types import TConfigSectionIn
TFileDisposition, TLineNo, TMorf, from coverage.types import TConfigurable
) from coverage.types import TConfigValueIn
from coverage.types import TConfigValueOut
from coverage.types import TFileDisposition
from coverage.types import TLineNo
from coverage.types import TMorf
from coverage.xmlreport import XmlReporter from coverage.xmlreport import XmlReporter
os = isolate_module(os) os = isolate_module(os)
@contextlib.contextmanager @contextlib.contextmanager
def override_config(cov: Coverage, **kwargs: TConfigValueIn) -> Iterator[None]: def override_config(cov: Coverage, **kwargs: TConfigValueIn) -> Iterator[None]:
"""Temporarily tweak the configuration of `cov`. """Temporarily tweak the configuration of `cov`.
@ -72,9 +91,10 @@ def override_config(cov: Coverage, **kwargs: TConfigValueIn) -> Iterator[None]:
cov.config = original_config cov.config = original_config
DEFAULT_DATAFILE = DefaultValue("MISSING") DEFAULT_DATAFILE = DefaultValue('MISSING')
_DEFAULT_DATAFILE = DEFAULT_DATAFILE # Just in case, for backwards compatibility _DEFAULT_DATAFILE = DEFAULT_DATAFILE # Just in case, for backwards compatibility
class Coverage(TConfigurable): class Coverage(TConfigurable):
"""Programmatic access to coverage.py. """Programmatic access to coverage.py.
@ -323,10 +343,10 @@ class Coverage(TConfigurable):
# Create and configure the debugging controller. # Create and configure the debugging controller.
self._debug = DebugControl(self.config.debug, self._debug_file, self.config.debug_file) self._debug = DebugControl(self.config.debug, self._debug_file, self.config.debug_file)
if self._debug.should("process"): if self._debug.should('process'):
self._debug.write("Coverage._init") self._debug.write('Coverage._init')
if "multiprocessing" in (self.config.concurrency or ()): if 'multiprocessing' in (self.config.concurrency or ()):
# Multi-processing uses parallel for the subprocesses, so also use # Multi-processing uses parallel for the subprocesses, so also use
# it for the main process. # it for the main process.
self.config.parallel = True self.config.parallel = True
@ -358,31 +378,31 @@ class Coverage(TConfigurable):
# "[run] _crash" will raise an exception if the value is close by in # "[run] _crash" will raise an exception if the value is close by in
# the call stack, for testing error handling. # the call stack, for testing error handling.
if self.config._crash and self.config._crash in short_stack(): if self.config._crash and self.config._crash in short_stack():
raise RuntimeError(f"Crashing because called by {self.config._crash}") raise RuntimeError(f'Crashing because called by {self.config._crash}')
def _write_startup_debug(self) -> None: def _write_startup_debug(self) -> None:
"""Write out debug info at startup if needed.""" """Write out debug info at startup if needed."""
wrote_any = False wrote_any = False
with self._debug.without_callers(): with self._debug.without_callers():
if self._debug.should("config"): if self._debug.should('config'):
config_info = self.config.debug_info() config_info = self.config.debug_info()
write_formatted_info(self._debug.write, "config", config_info) write_formatted_info(self._debug.write, 'config', config_info)
wrote_any = True wrote_any = True
if self._debug.should("sys"): if self._debug.should('sys'):
write_formatted_info(self._debug.write, "sys", self.sys_info()) write_formatted_info(self._debug.write, 'sys', self.sys_info())
for plugin in self._plugins: for plugin in self._plugins:
header = "sys: " + plugin._coverage_plugin_name header = 'sys: ' + plugin._coverage_plugin_name
info = plugin.sys_info() info = plugin.sys_info()
write_formatted_info(self._debug.write, header, info) write_formatted_info(self._debug.write, header, info)
wrote_any = True wrote_any = True
if self._debug.should("pybehave"): if self._debug.should('pybehave'):
write_formatted_info(self._debug.write, "pybehave", env.debug_info()) write_formatted_info(self._debug.write, 'pybehave', env.debug_info())
wrote_any = True wrote_any = True
if wrote_any: if wrote_any:
write_formatted_info(self._debug.write, "end", ()) write_formatted_info(self._debug.write, 'end', ())
def _should_trace(self, filename: str, frame: FrameType) -> TFileDisposition: def _should_trace(self, filename: str, frame: FrameType) -> TFileDisposition:
"""Decide whether to trace execution in `filename`. """Decide whether to trace execution in `filename`.
@ -392,7 +412,7 @@ class Coverage(TConfigurable):
""" """
assert self._inorout is not None assert self._inorout is not None
disp = self._inorout.should_trace(filename, frame) disp = self._inorout.should_trace(filename, frame)
if self._debug.should("trace"): if self._debug.should('trace'):
self._debug.write(disposition_debug_msg(disp)) self._debug.write(disposition_debug_msg(disp))
return disp return disp
@ -404,11 +424,11 @@ class Coverage(TConfigurable):
""" """
assert self._inorout is not None assert self._inorout is not None
reason = self._inorout.check_include_omit_etc(filename, frame) reason = self._inorout.check_include_omit_etc(filename, frame)
if self._debug.should("trace"): if self._debug.should('trace'):
if not reason: if not reason:
msg = f"Including {filename!r}" msg = f'Including {filename!r}'
else: else:
msg = f"Not including {filename!r}: {reason}" msg = f'Not including {filename!r}: {reason}'
self._debug.write(msg) self._debug.write(msg)
return not reason return not reason
@ -431,9 +451,9 @@ class Coverage(TConfigurable):
self._warnings.append(msg) self._warnings.append(msg)
if slug: if slug:
msg = f"{msg} ({slug})" msg = f'{msg} ({slug})'
if self._debug.should("pid"): if self._debug.should('pid'):
msg = f"[{os.getpid()}] {msg}" msg = f'[{os.getpid()}] {msg}'
warnings.warn(msg, category=CoverageWarning, stacklevel=2) warnings.warn(msg, category=CoverageWarning, stacklevel=2)
if once: if once:
@ -512,15 +532,15 @@ class Coverage(TConfigurable):
"""Initialization for start()""" """Initialization for start()"""
# Construct the collector. # Construct the collector.
concurrency: list[str] = self.config.concurrency or [] concurrency: list[str] = self.config.concurrency or []
if "multiprocessing" in concurrency: if 'multiprocessing' in concurrency:
if self.config.config_file is None: if self.config.config_file is None:
raise ConfigError("multiprocessing requires a configuration file") raise ConfigError('multiprocessing requires a configuration file')
patch_multiprocessing(rcfile=self.config.config_file) patch_multiprocessing(rcfile=self.config.config_file)
dycon = self.config.dynamic_context dycon = self.config.dynamic_context
if not dycon or dycon == "none": if not dycon or dycon == 'none':
context_switchers = [] context_switchers = []
elif dycon == "test_function": elif dycon == 'test_function':
context_switchers = [should_start_context_test_function] context_switchers = [should_start_context_test_function]
else: else:
raise ConfigError(f"Don't understand dynamic_context setting: {dycon!r}") raise ConfigError(f"Don't understand dynamic_context setting: {dycon!r}")
@ -565,9 +585,9 @@ class Coverage(TConfigurable):
if self._plugins.file_tracers and not self._collector.supports_plugins: if self._plugins.file_tracers and not self._collector.supports_plugins:
self._warn( self._warn(
"Plugin file tracers ({}) aren't supported with {}".format( "Plugin file tracers ({}) aren't supported with {}".format(
", ".join( ', '.join(
plugin._coverage_plugin_name plugin._coverage_plugin_name
for plugin in self._plugins.file_tracers for plugin in self._plugins.file_tracers
), ),
self._collector.tracer_name(), self._collector.tracer_name(),
), ),
@ -579,7 +599,7 @@ class Coverage(TConfigurable):
self._inorout = InOrOut( self._inorout = InOrOut(
config=self.config, config=self.config,
warn=self._warn, warn=self._warn,
debug=(self._debug if self._debug.should("trace") else None), debug=(self._debug if self._debug.should('trace') else None),
include_namespace_packages=self.config.include_namespace_packages, include_namespace_packages=self.config.include_namespace_packages,
) )
self._inorout.plugins = self._plugins self._inorout.plugins = self._plugins
@ -676,18 +696,18 @@ class Coverage(TConfigurable):
finally: finally:
self.stop() self.stop()
def _atexit(self, event: str = "atexit") -> None: def _atexit(self, event: str = 'atexit') -> None:
"""Clean up on process shutdown.""" """Clean up on process shutdown."""
if self._debug.should("process"): if self._debug.should('process'):
self._debug.write(f"{event}: pid: {os.getpid()}, instance: {self!r}") self._debug.write(f'{event}: pid: {os.getpid()}, instance: {self!r}')
if self._started: if self._started:
self.stop() self.stop()
if self._auto_save or event == "sigterm": if self._auto_save or event == 'sigterm':
self.save() self.save()
def _on_sigterm(self, signum_unused: int, frame_unused: FrameType | None) -> None: def _on_sigterm(self, signum_unused: int, frame_unused: FrameType | None) -> None:
"""A handler for signal.SIGTERM.""" """A handler for signal.SIGTERM."""
self._atexit("sigterm") self._atexit('sigterm')
# Statements after here won't be seen by metacov because we just wrote # Statements after here won't be seen by metacov because we just wrote
# the data, and are about to kill the process. # the data, and are about to kill the process.
signal.signal(signal.SIGTERM, self._old_sigterm) # pragma: not covered signal.signal(signal.SIGTERM, self._old_sigterm) # pragma: not covered
@ -724,21 +744,21 @@ class Coverage(TConfigurable):
""" """
if not self._started: # pragma: part started if not self._started: # pragma: part started
raise CoverageException("Cannot switch context, coverage is not started") raise CoverageException('Cannot switch context, coverage is not started')
assert self._collector is not None assert self._collector is not None
if self._collector.should_start_context: if self._collector.should_start_context:
self._warn("Conflicting dynamic contexts", slug="dynamic-conflict", once=True) self._warn('Conflicting dynamic contexts', slug='dynamic-conflict', once=True)
self._collector.switch_context(new_context) self._collector.switch_context(new_context)
def clear_exclude(self, which: str = "exclude") -> None: def clear_exclude(self, which: str = 'exclude') -> None:
"""Clear the exclude list.""" """Clear the exclude list."""
self._init() self._init()
setattr(self.config, which + "_list", []) setattr(self.config, which + '_list', [])
self._exclude_regex_stale() self._exclude_regex_stale()
def exclude(self, regex: str, which: str = "exclude") -> None: def exclude(self, regex: str, which: str = 'exclude') -> None:
"""Exclude source lines from execution consideration. """Exclude source lines from execution consideration.
A number of lists of regular expressions are maintained. Each list A number of lists of regular expressions are maintained. Each list
@ -754,7 +774,7 @@ class Coverage(TConfigurable):
""" """
self._init() self._init()
excl_list = getattr(self.config, which + "_list") excl_list = getattr(self.config, which + '_list')
excl_list.append(regex) excl_list.append(regex)
self._exclude_regex_stale() self._exclude_regex_stale()
@ -765,11 +785,11 @@ class Coverage(TConfigurable):
def _exclude_regex(self, which: str) -> str: def _exclude_regex(self, which: str) -> str:
"""Return a regex string for the given exclusion list.""" """Return a regex string for the given exclusion list."""
if which not in self._exclude_re: if which not in self._exclude_re:
excl_list = getattr(self.config, which + "_list") excl_list = getattr(self.config, which + '_list')
self._exclude_re[which] = join_regex(excl_list) self._exclude_re[which] = join_regex(excl_list)
return self._exclude_re[which] return self._exclude_re[which]
def get_exclude_list(self, which: str = "exclude") -> list[str]: def get_exclude_list(self, which: str = 'exclude') -> list[str]:
"""Return a list of excluded regex strings. """Return a list of excluded regex strings.
`which` indicates which list is desired. See :meth:`exclude` for the `which` indicates which list is desired. See :meth:`exclude` for the
@ -777,7 +797,7 @@ class Coverage(TConfigurable):
""" """
self._init() self._init()
return cast(List[str], getattr(self.config, which + "_list")) return cast(List[str], getattr(self.config, which + '_list'))
def save(self) -> None: def save(self) -> None:
"""Save the collected coverage data to the data file.""" """Save the collected coverage data to the data file."""
@ -787,7 +807,7 @@ class Coverage(TConfigurable):
def _make_aliases(self) -> PathAliases: def _make_aliases(self) -> PathAliases:
"""Create a PathAliases from our configuration.""" """Create a PathAliases from our configuration."""
aliases = PathAliases( aliases = PathAliases(
debugfn=(self._debug.write if self._debug.should("pathmap") else None), debugfn=(self._debug.write if self._debug.should('pathmap') else None),
relative=self.config.relative_files, relative=self.config.relative_files,
) )
for paths in self.config.paths.values(): for paths in self.config.paths.values():
@ -884,7 +904,7 @@ class Coverage(TConfigurable):
# Find out if we got any data. # Find out if we got any data.
if not self._data and self._warn_no_data: if not self._data and self._warn_no_data:
self._warn("No data was collected.", slug="no-data-collected") self._warn('No data was collected.', slug='no-data-collected')
# Touch all the files that could have executed, so that we can # Touch all the files that could have executed, so that we can
# mark completely un-executed files as 0% covered. # mark completely un-executed files as 0% covered.
@ -952,7 +972,7 @@ class Coverage(TConfigurable):
"""Get a FileReporter for a module or file name.""" """Get a FileReporter for a module or file name."""
assert self._data is not None assert self._data is not None
plugin = None plugin = None
file_reporter: str | FileReporter = "python" file_reporter: str | FileReporter = 'python'
if isinstance(morf, str): if isinstance(morf, str):
mapped_morf = self._file_mapper(morf) mapped_morf = self._file_mapper(morf)
@ -964,12 +984,12 @@ class Coverage(TConfigurable):
file_reporter = plugin.file_reporter(mapped_morf) file_reporter = plugin.file_reporter(mapped_morf)
if file_reporter is None: if file_reporter is None:
raise PluginError( raise PluginError(
"Plugin {!r} did not provide a file reporter for {!r}.".format( 'Plugin {!r} did not provide a file reporter for {!r}.'.format(
plugin._coverage_plugin_name, morf, plugin._coverage_plugin_name, morf,
), ),
) )
if file_reporter == "python": if file_reporter == 'python':
file_reporter = PythonFileReporter(morf, self) file_reporter = PythonFileReporter(morf, self)
assert isinstance(file_reporter, FileReporter) assert isinstance(file_reporter, FileReporter)
@ -1290,36 +1310,37 @@ class Coverage(TConfigurable):
for plugin in plugins: for plugin in plugins:
entry = plugin._coverage_plugin_name entry = plugin._coverage_plugin_name
if not plugin._coverage_enabled: if not plugin._coverage_enabled:
entry += " (disabled)" entry += ' (disabled)'
entries.append(entry) entries.append(entry)
return entries return entries
info = [ info = [
("coverage_version", covmod.__version__), ('coverage_version', covmod.__version__),
("coverage_module", covmod.__file__), ('coverage_module', covmod.__file__),
("core", self._collector.tracer_name() if self._collector is not None else "-none-"), ('core', self._collector.tracer_name() if self._collector is not None else '-none-'),
("CTracer", "available" if HAS_CTRACER else "unavailable"), ('CTracer', 'available' if HAS_CTRACER else 'unavailable'),
("plugins.file_tracers", plugin_info(self._plugins.file_tracers)), ('plugins.file_tracers', plugin_info(self._plugins.file_tracers)),
("plugins.configurers", plugin_info(self._plugins.configurers)), ('plugins.configurers', plugin_info(self._plugins.configurers)),
("plugins.context_switchers", plugin_info(self._plugins.context_switchers)), ('plugins.context_switchers', plugin_info(self._plugins.context_switchers)),
("configs_attempted", self.config.attempted_config_files), ('configs_attempted', self.config.attempted_config_files),
("configs_read", self.config.config_files_read), ('configs_read', self.config.config_files_read),
("config_file", self.config.config_file), ('config_file', self.config.config_file),
("config_contents", (
repr(self.config._config_contents) if self.config._config_contents else "-none-", 'config_contents',
repr(self.config._config_contents) if self.config._config_contents else '-none-',
), ),
("data_file", self._data.data_filename() if self._data is not None else "-none-"), ('data_file', self._data.data_filename() if self._data is not None else '-none-'),
("python", sys.version.replace("\n", "")), ('python', sys.version.replace('\n', '')),
("platform", platform.platform()), ('platform', platform.platform()),
("implementation", platform.python_implementation()), ('implementation', platform.python_implementation()),
("executable", sys.executable), ('executable', sys.executable),
("def_encoding", sys.getdefaultencoding()), ('def_encoding', sys.getdefaultencoding()),
("fs_encoding", sys.getfilesystemencoding()), ('fs_encoding', sys.getfilesystemencoding()),
("pid", os.getpid()), ('pid', os.getpid()),
("cwd", os.getcwd()), ('cwd', os.getcwd()),
("path", sys.path), ('path', sys.path),
("environment", [f"{k} = {v}" for k, v in relevant_environment_display(os.environ)]), ('environment', [f'{k} = {v}' for k, v in relevant_environment_display(os.environ)]),
("command_line", " ".join(getattr(sys, "argv", ["-none-"]))), ('command_line', ' '.join(getattr(sys, 'argv', ['-none-']))),
] ]
if self._inorout is not None: if self._inorout is not None:
@ -1332,12 +1353,12 @@ class Coverage(TConfigurable):
# Mega debugging... # Mega debugging...
# $set_env.py: COVERAGE_DEBUG_CALLS - Lots and lots of output about calls to Coverage. # $set_env.py: COVERAGE_DEBUG_CALLS - Lots and lots of output about calls to Coverage.
if int(os.getenv("COVERAGE_DEBUG_CALLS", 0)): # pragma: debugging if int(os.getenv('COVERAGE_DEBUG_CALLS', 0)): # pragma: debugging
from coverage.debug import decorate_methods, show_calls from coverage.debug import decorate_methods, show_calls
Coverage = decorate_methods( # type: ignore[misc] Coverage = decorate_methods( # type: ignore[misc]
show_calls(show_args=True), show_calls(show_args=True),
butnot=["get_data"], butnot=['get_data'],
)(Coverage) )(Coverage)
@ -1364,7 +1385,7 @@ def process_startup() -> Coverage | None:
not started by this call. not started by this call.
""" """
cps = os.getenv("COVERAGE_PROCESS_START") cps = os.getenv('COVERAGE_PROCESS_START')
if not cps: if not cps:
# No request for coverage, nothing to do. # No request for coverage, nothing to do.
return None return None
@ -1378,7 +1399,7 @@ def process_startup() -> Coverage | None:
# #
# https://github.com/nedbat/coveragepy/issues/340 has more details. # https://github.com/nedbat/coveragepy/issues/340 has more details.
if hasattr(process_startup, "coverage"): if hasattr(process_startup, 'coverage'):
# We've annotated this function before, so we must have already # We've annotated this function before, so we must have already
# started coverage.py in this process. Nothing to do. # started coverage.py in this process. Nothing to do.
return None return None
@ -1396,6 +1417,6 @@ def process_startup() -> Coverage | None:
def _prevent_sub_process_measurement() -> None: def _prevent_sub_process_measurement() -> None:
"""Stop any subprocess auto-measurement from writing data.""" """Stop any subprocess auto-measurement from writing data."""
auto_created_coverage = getattr(process_startup, "coverage", None) auto_created_coverage = getattr(process_startup, 'coverage', None)
if auto_created_coverage is not None: if auto_created_coverage is not None:
auto_created_coverage._auto_save = False auto_created_coverage._auto_save = False

View file

@ -1,6 +1,5 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Coverage data for coverage.py. """Coverage data for coverage.py.
This file had the 4.x JSON data support, which is now gone. This file still This file had the 4.x JSON data support, which is now gone. This file still
@ -9,18 +8,21 @@ CoverageData is now defined in sqldata.py, and imported here to keep the
imports working. imports working.
""" """
from __future__ import annotations from __future__ import annotations
import glob import glob
import hashlib import hashlib
import os.path import os.path
from typing import Callable
from typing import Iterable
from typing import Callable, Iterable from coverage.exceptions import CoverageException
from coverage.exceptions import NoDataError
from coverage.exceptions import CoverageException, NoDataError
from coverage.files import PathAliases from coverage.files import PathAliases
from coverage.misc import Hasher, file_be_gone, human_sorted, plural from coverage.misc import file_be_gone
from coverage.misc import Hasher
from coverage.misc import human_sorted
from coverage.misc import plural
from coverage.sqldata import CoverageData from coverage.sqldata import CoverageData
@ -38,7 +40,7 @@ def line_counts(data: CoverageData, fullpath: bool = False) -> dict[str, int]:
filename_fn: Callable[[str], str] filename_fn: Callable[[str], str]
if fullpath: if fullpath:
# pylint: disable=unnecessary-lambda-assignment # pylint: disable=unnecessary-lambda-assignment
filename_fn = lambda f: f def filename_fn(f): return f
else: else:
filename_fn = os.path.basename filename_fn = os.path.basename
for filename in data.measured_files(): for filename in data.measured_files():
@ -79,14 +81,14 @@ def combinable_files(data_file: str, data_paths: Iterable[str] | None = None) ->
if os.path.isfile(p): if os.path.isfile(p):
files_to_combine.append(os.path.abspath(p)) files_to_combine.append(os.path.abspath(p))
elif os.path.isdir(p): elif os.path.isdir(p):
pattern = glob.escape(os.path.join(os.path.abspath(p), local)) +".*" pattern = glob.escape(os.path.join(os.path.abspath(p), local)) + '.*'
files_to_combine.extend(glob.glob(pattern)) files_to_combine.extend(glob.glob(pattern))
else: else:
raise NoDataError(f"Couldn't combine from non-existent path '{p}'") raise NoDataError(f"Couldn't combine from non-existent path '{p}'")
# SQLite might have made journal files alongside our database files. # SQLite might have made journal files alongside our database files.
# We never want to combine those. # We never want to combine those.
files_to_combine = [fnm for fnm in files_to_combine if not fnm.endswith("-journal")] files_to_combine = [fnm for fnm in files_to_combine if not fnm.endswith('-journal')]
# Sorting isn't usually needed, since it shouldn't matter what order files # Sorting isn't usually needed, since it shouldn't matter what order files
# are combined, but sorting makes tests more predictable, and makes # are combined, but sorting makes tests more predictable, and makes
@ -132,7 +134,7 @@ def combine_parallel_data(
files_to_combine = combinable_files(data.base_filename(), data_paths) files_to_combine = combinable_files(data.base_filename(), data_paths)
if strict and not files_to_combine: if strict and not files_to_combine:
raise NoDataError("No data to combine") raise NoDataError('No data to combine')
file_hashes = set() file_hashes = set()
combined_any = False combined_any = False
@ -141,8 +143,8 @@ def combine_parallel_data(
if f == data.data_filename(): if f == data.data_filename():
# Sometimes we are combining into a file which is one of the # Sometimes we are combining into a file which is one of the
# parallel files. Skip that file. # parallel files. Skip that file.
if data._debug.should("dataio"): if data._debug.should('dataio'):
data._debug.write(f"Skipping combining ourself: {f!r}") data._debug.write(f'Skipping combining ourself: {f!r}')
continue continue
try: try:
@ -153,16 +155,16 @@ def combine_parallel_data(
# we print the original value of f instead of its relative path # we print the original value of f instead of its relative path
rel_file_name = f rel_file_name = f
with open(f, "rb") as fobj: with open(f, 'rb') as fobj:
hasher = hashlib.new("sha3_256") hasher = hashlib.new('sha3_256')
hasher.update(fobj.read()) hasher.update(fobj.read())
sha = hasher.digest() sha = hasher.digest()
combine_this_one = sha not in file_hashes combine_this_one = sha not in file_hashes
delete_this_one = not keep delete_this_one = not keep
if combine_this_one: if combine_this_one:
if data._debug.should("dataio"): if data._debug.should('dataio'):
data._debug.write(f"Combining data file {f!r}") data._debug.write(f'Combining data file {f!r}')
file_hashes.add(sha) file_hashes.add(sha)
try: try:
new_data = CoverageData(f, debug=data._debug) new_data = CoverageData(f, debug=data._debug)
@ -179,39 +181,39 @@ def combine_parallel_data(
data.update(new_data, aliases=aliases) data.update(new_data, aliases=aliases)
combined_any = True combined_any = True
if message: if message:
message(f"Combined data file {rel_file_name}") message(f'Combined data file {rel_file_name}')
else: else:
if message: if message:
message(f"Skipping duplicate data {rel_file_name}") message(f'Skipping duplicate data {rel_file_name}')
if delete_this_one: if delete_this_one:
if data._debug.should("dataio"): if data._debug.should('dataio'):
data._debug.write(f"Deleting data file {f!r}") data._debug.write(f'Deleting data file {f!r}')
file_be_gone(f) file_be_gone(f)
if strict and not combined_any: if strict and not combined_any:
raise NoDataError("No usable data files") raise NoDataError('No usable data files')
def debug_data_file(filename: str) -> None: def debug_data_file(filename: str) -> None:
"""Implementation of 'coverage debug data'.""" """Implementation of 'coverage debug data'."""
data = CoverageData(filename) data = CoverageData(filename)
filename = data.data_filename() filename = data.data_filename()
print(f"path: {filename}") print(f'path: {filename}')
if not os.path.exists(filename): if not os.path.exists(filename):
print("No data collected: file doesn't exist") print("No data collected: file doesn't exist")
return return
data.read() data.read()
print(f"has_arcs: {data.has_arcs()!r}") print(f'has_arcs: {data.has_arcs()!r}')
summary = line_counts(data, fullpath=True) summary = line_counts(data, fullpath=True)
filenames = human_sorted(summary.keys()) filenames = human_sorted(summary.keys())
nfiles = len(filenames) nfiles = len(filenames)
print(f"{nfiles} file{plural(nfiles)}:") print(f'{nfiles} file{plural(nfiles)}:')
for f in filenames: for f in filenames:
line = f"{f}: {summary[f]} line{plural(summary[f])}" line = f'{f}: {summary[f]} line{plural(summary[f])}'
plugin = data.file_tracer(f) plugin = data.file_tracer(f)
if plugin: if plugin:
line += f" [{plugin}]" line += f' [{plugin}]'
print(line) print(line)

View file

@ -1,10 +1,9 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Control of and utilities for debugging.""" """Control of and utilities for debugging."""
from __future__ import annotations from __future__ import annotations
import _thread
import atexit import atexit
import contextlib import contextlib
import functools import functools
@ -17,15 +16,18 @@ import reprlib
import sys import sys
import traceback import traceback
import types import types
import _thread from typing import Any
from typing import Callable
from typing import IO
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import overload
from typing import ( from coverage.misc import human_sorted_items
overload, from coverage.misc import isolate_module
Any, Callable, IO, Iterable, Iterator, Mapping, from coverage.types import AnyCallable
) from coverage.types import TWritable
from coverage.misc import human_sorted_items, isolate_module
from coverage.types import AnyCallable, TWritable
os = isolate_module(os) os = isolate_module(os)
@ -53,12 +55,12 @@ class DebugControl:
self.suppress_callers = False self.suppress_callers = False
filters = [] filters = []
if self.should("process"): if self.should('process'):
filters.append(CwdTracker().filter) filters.append(CwdTracker().filter)
filters.append(ProcessTracker().filter) filters.append(ProcessTracker().filter)
if self.should("pytest"): if self.should('pytest'):
filters.append(PytestTracker().filter) filters.append(PytestTracker().filter)
if self.should("pid"): if self.should('pid'):
filters.append(add_pid_and_tid) filters.append(add_pid_and_tid)
self.output = DebugOutputFile.get_one( self.output = DebugOutputFile.get_one(
@ -69,11 +71,11 @@ class DebugControl:
self.raw_output = self.output.outfile self.raw_output = self.output.outfile
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<DebugControl options={self.options!r} raw_output={self.raw_output!r}>" return f'<DebugControl options={self.options!r} raw_output={self.raw_output!r}>'
def should(self, option: str) -> bool: def should(self, option: str) -> bool:
"""Decide whether to output debug information in category `option`.""" """Decide whether to output debug information in category `option`."""
if option == "callers" and self.suppress_callers: if option == 'callers' and self.suppress_callers:
return False return False
return (option in self.options) return (option in self.options)
@ -96,20 +98,21 @@ class DebugControl:
after the message. after the message.
""" """
self.output.write(msg + "\n") self.output.write(msg + '\n')
if exc is not None: if exc is not None:
self.output.write("".join(traceback.format_exception(None, exc, exc.__traceback__))) self.output.write(''.join(traceback.format_exception(None, exc, exc.__traceback__)))
if self.should("self"): if self.should('self'):
caller_self = inspect.stack()[1][0].f_locals.get("self") caller_self = inspect.stack()[1][0].f_locals.get('self')
if caller_self is not None: if caller_self is not None:
self.output.write(f"self: {caller_self!r}\n") self.output.write(f'self: {caller_self!r}\n')
if self.should("callers"): if self.should('callers'):
dump_stack_frames(out=self.output, skip=1) dump_stack_frames(out=self.output, skip=1)
self.output.flush() self.output.flush()
class NoDebugging(DebugControl): class NoDebugging(DebugControl):
"""A replacement for DebugControl that will never try to do anything.""" """A replacement for DebugControl that will never try to do anything."""
def __init__(self) -> None: def __init__(self) -> None:
# pylint: disable=super-init-not-called # pylint: disable=super-init-not-called
... ...
@ -120,12 +123,12 @@ class NoDebugging(DebugControl):
def write(self, msg: str, *, exc: BaseException | None = None) -> None: def write(self, msg: str, *, exc: BaseException | None = None) -> None:
"""This will never be called.""" """This will never be called."""
raise AssertionError("NoDebugging.write should never be called.") raise AssertionError('NoDebugging.write should never be called.')
def info_header(label: str) -> str: def info_header(label: str) -> str:
"""Make a nice header string.""" """Make a nice header string."""
return "--{:-<60s}".format(" "+label+" ") return '--{:-<60s}'.format(' ' + label + ' ')
def info_formatter(info: Iterable[tuple[str, Any]]) -> Iterator[str]: def info_formatter(info: Iterable[tuple[str, Any]]) -> Iterator[str]:
@ -142,17 +145,17 @@ def info_formatter(info: Iterable[tuple[str, Any]]) -> Iterator[str]:
assert all(len(l) < label_len for l, _ in info) assert all(len(l) < label_len for l, _ in info)
for label, data in info: for label, data in info:
if data == []: if data == []:
data = "-none-" data = '-none-'
if isinstance(data, tuple) and len(repr(tuple(data))) < 30: if isinstance(data, tuple) and len(repr(tuple(data))) < 30:
# Convert to tuple to scrub namedtuples. # Convert to tuple to scrub namedtuples.
yield "%*s: %r" % (label_len, label, tuple(data)) yield '%*s: %r' % (label_len, label, tuple(data))
elif isinstance(data, (list, set, tuple)): elif isinstance(data, (list, set, tuple)):
prefix = "%*s:" % (label_len, label) prefix = '%*s:' % (label_len, label)
for e in data: for e in data:
yield "%*s %s" % (label_len+1, prefix, e) yield '%*s %s' % (label_len + 1, prefix, e)
prefix = "" prefix = ''
else: else:
yield "%*s: %s" % (label_len, label, data) yield '%*s: %s' % (label_len, label, data)
def write_formatted_info( def write_formatted_info(
@ -170,35 +173,38 @@ def write_formatted_info(
""" """
write(info_header(header)) write(info_header(header))
for line in info_formatter(info): for line in info_formatter(info):
write(f" {line}") write(f' {line}')
def exc_one_line(exc: Exception) -> str: def exc_one_line(exc: Exception) -> str:
"""Get a one-line summary of an exception, including class name and message.""" """Get a one-line summary of an exception, including class name and message."""
lines = traceback.format_exception_only(type(exc), exc) lines = traceback.format_exception_only(type(exc), exc)
return "|".join(l.rstrip() for l in lines) return '|'.join(l.rstrip() for l in lines)
_FILENAME_REGEXES: list[tuple[str, str]] = [ _FILENAME_REGEXES: list[tuple[str, str]] = [
(r".*[/\\]pytest-of-.*[/\\]pytest-\d+([/\\]popen-gw\d+)?", "tmp:"), (r'.*[/\\]pytest-of-.*[/\\]pytest-\d+([/\\]popen-gw\d+)?', 'tmp:'),
] ]
_FILENAME_SUBS: list[tuple[str, str]] = [] _FILENAME_SUBS: list[tuple[str, str]] = []
@overload @overload
def short_filename(filename: str) -> str: def short_filename(filename: str) -> str:
pass pass
@overload @overload
def short_filename(filename: None) -> None: def short_filename(filename: None) -> None:
pass pass
def short_filename(filename: str | None) -> str | None: def short_filename(filename: str | None) -> str | None:
"""Shorten a file name. Directories are replaced by prefixes like 'syspath:'""" """Shorten a file name. Directories are replaced by prefixes like 'syspath:'"""
if not _FILENAME_SUBS: if not _FILENAME_SUBS:
for pathdir in sys.path: for pathdir in sys.path:
_FILENAME_SUBS.append((pathdir, "syspath:")) _FILENAME_SUBS.append((pathdir, 'syspath:'))
import coverage import coverage
_FILENAME_SUBS.append((os.path.dirname(coverage.__file__), "cov:")) _FILENAME_SUBS.append((os.path.dirname(coverage.__file__), 'cov:'))
_FILENAME_SUBS.sort(key=(lambda pair: len(pair[0])), reverse=True) _FILENAME_SUBS.sort(key=(lambda pair: len(pair[0])), reverse=True)
if filename is not None: if filename is not None:
for pat, sub in _FILENAME_REGEXES: for pat, sub in _FILENAME_REGEXES:
@ -237,9 +243,9 @@ def short_stack(
""" """
# Regexes in initial frames that we don't care about. # Regexes in initial frames that we don't care about.
BORING_PRELUDE = [ BORING_PRELUDE = [
"<string>", # pytest-xdist has string execution. '<string>', # pytest-xdist has string execution.
r"\bigor.py$", # Our test runner. r'\bigor.py$', # Our test runner.
r"\bsite-packages\b", # pytest etc getting to our tests. r'\bsite-packages\b', # pytest etc getting to our tests.
] ]
stack: Iterable[inspect.FrameInfo] = inspect.stack()[:skip:-1] stack: Iterable[inspect.FrameInfo] = inspect.stack()[:skip:-1]
@ -251,20 +257,20 @@ def short_stack(
) )
lines = [] lines = []
for frame_info in stack: for frame_info in stack:
line = f"{frame_info.function:>30s} : " line = f'{frame_info.function:>30s} : '
if frame_ids: if frame_ids:
line += f"{id(frame_info.frame):#x} " line += f'{id(frame_info.frame):#x} '
filename = frame_info.filename filename = frame_info.filename
if short_filenames: if short_filenames:
filename = short_filename(filename) filename = short_filename(filename)
line += f"{filename}:{frame_info.lineno}" line += f'{filename}:{frame_info.lineno}'
lines.append(line) lines.append(line)
return "\n".join(lines) return '\n'.join(lines)
def dump_stack_frames(out: TWritable, skip: int = 0) -> None: def dump_stack_frames(out: TWritable, skip: int = 0) -> None:
"""Print a summary of the stack to `out`.""" """Print a summary of the stack to `out`."""
out.write(short_stack(skip=skip+1) + "\n") out.write(short_stack(skip=skip + 1) + '\n')
def clipped_repr(text: str, numchars: int = 50) -> str: def clipped_repr(text: str, numchars: int = 50) -> str:
@ -285,36 +291,37 @@ def short_id(id64: int) -> int:
def add_pid_and_tid(text: str) -> str: def add_pid_and_tid(text: str) -> str:
"""A filter to add pid and tid to debug messages.""" """A filter to add pid and tid to debug messages."""
# Thread ids are useful, but too long. Make a shorter one. # Thread ids are useful, but too long. Make a shorter one.
tid = f"{short_id(_thread.get_ident()):04x}" tid = f'{short_id(_thread.get_ident()):04x}'
text = f"{os.getpid():5d}.{tid}: {text}" text = f'{os.getpid():5d}.{tid}: {text}'
return text return text
AUTO_REPR_IGNORE = {"$coverage.object_id"} AUTO_REPR_IGNORE = {'$coverage.object_id'}
def auto_repr(self: Any) -> str: def auto_repr(self: Any) -> str:
"""A function implementing an automatic __repr__ for debugging.""" """A function implementing an automatic __repr__ for debugging."""
show_attrs = ( show_attrs = (
(k, v) for k, v in self.__dict__.items() (k, v) for k, v in self.__dict__.items()
if getattr(v, "show_repr_attr", True) if getattr(v, 'show_repr_attr', True) and
and not inspect.ismethod(v) not inspect.ismethod(v) and
and k not in AUTO_REPR_IGNORE k not in AUTO_REPR_IGNORE
) )
return "<{klass} @{id:#x}{attrs}>".format( return '<{klass} @{id:#x}{attrs}>'.format(
klass=self.__class__.__name__, klass=self.__class__.__name__,
id=id(self), id=id(self),
attrs="".join(f" {k}={v!r}" for k, v in show_attrs), attrs=''.join(f' {k}={v!r}' for k, v in show_attrs),
) )
def simplify(v: Any) -> Any: # pragma: debugging def simplify(v: Any) -> Any: # pragma: debugging
"""Turn things which are nearly dict/list/etc into dict/list/etc.""" """Turn things which are nearly dict/list/etc into dict/list/etc."""
if isinstance(v, dict): if isinstance(v, dict):
return {k:simplify(vv) for k, vv in v.items()} return {k: simplify(vv) for k, vv in v.items()}
elif isinstance(v, (list, tuple)): elif isinstance(v, (list, tuple)):
return type(v)(simplify(vv) for vv in v) return type(v)(simplify(vv) for vv in v)
elif hasattr(v, "__dict__"): elif hasattr(v, '__dict__'):
return simplify({"."+k: v for k, v in v.__dict__.items()}) return simplify({'.' + k: v for k, v in v.__dict__.items()})
else: else:
return v return v
@ -343,12 +350,13 @@ def filter_text(text: str, filters: Iterable[Callable[[str], str]]) -> str:
lines = [] lines = []
for line in text.splitlines(): for line in text.splitlines():
lines.extend(filter_fn(line).splitlines()) lines.extend(filter_fn(line).splitlines())
text = "\n".join(lines) text = '\n'.join(lines)
return text + ending return text + ending
class CwdTracker: class CwdTracker:
"""A class to add cwd info to debug messages.""" """A class to add cwd info to debug messages."""
def __init__(self) -> None: def __init__(self) -> None:
self.cwd: str | None = None self.cwd: str | None = None
@ -356,32 +364,33 @@ class CwdTracker:
"""Add a cwd message for each new cwd.""" """Add a cwd message for each new cwd."""
cwd = os.getcwd() cwd = os.getcwd()
if cwd != self.cwd: if cwd != self.cwd:
text = f"cwd is now {cwd!r}\n" + text text = f'cwd is now {cwd!r}\n' + text
self.cwd = cwd self.cwd = cwd
return text return text
class ProcessTracker: class ProcessTracker:
"""Track process creation for debug logging.""" """Track process creation for debug logging."""
def __init__(self) -> None: def __init__(self) -> None:
self.pid: int = os.getpid() self.pid: int = os.getpid()
self.did_welcome = False self.did_welcome = False
def filter(self, text: str) -> str: def filter(self, text: str) -> str:
"""Add a message about how new processes came to be.""" """Add a message about how new processes came to be."""
welcome = "" welcome = ''
pid = os.getpid() pid = os.getpid()
if self.pid != pid: if self.pid != pid:
welcome = f"New process: forked {self.pid} -> {pid}\n" welcome = f'New process: forked {self.pid} -> {pid}\n'
self.pid = pid self.pid = pid
elif not self.did_welcome: elif not self.did_welcome:
argv = getattr(sys, "argv", None) argv = getattr(sys, 'argv', None)
welcome = ( welcome = (
f"New process: {pid=}, executable: {sys.executable!r}\n" f'New process: {pid=}, executable: {sys.executable!r}\n' +
+ f"New process: cmd: {argv!r}\n" f'New process: cmd: {argv!r}\n'
) )
if hasattr(os, "getppid"): if hasattr(os, 'getppid'):
welcome += f"New process parent pid: {os.getppid()!r}\n" welcome += f'New process parent pid: {os.getppid()!r}\n'
if welcome: if welcome:
self.did_welcome = True self.did_welcome = True
@ -392,20 +401,22 @@ class ProcessTracker:
class PytestTracker: class PytestTracker:
"""Track the current pytest test name to add to debug messages.""" """Track the current pytest test name to add to debug messages."""
def __init__(self) -> None: def __init__(self) -> None:
self.test_name: str | None = None self.test_name: str | None = None
def filter(self, text: str) -> str: def filter(self, text: str) -> str:
"""Add a message when the pytest test changes.""" """Add a message when the pytest test changes."""
test_name = os.getenv("PYTEST_CURRENT_TEST") test_name = os.getenv('PYTEST_CURRENT_TEST')
if test_name != self.test_name: if test_name != self.test_name:
text = f"Pytest context: {test_name}\n" + text text = f'Pytest context: {test_name}\n' + text
self.test_name = test_name self.test_name = test_name
return text return text
class DebugOutputFile: class DebugOutputFile:
"""A file-like object that includes pid and cwd information.""" """A file-like object that includes pid and cwd information."""
def __init__( def __init__(
self, self,
outfile: IO[str] | None, outfile: IO[str] | None,
@ -444,21 +455,21 @@ class DebugOutputFile:
the_one, is_interim = cls._get_singleton_data() the_one, is_interim = cls._get_singleton_data()
if the_one is None or is_interim: if the_one is None or is_interim:
if file_name is not None: if file_name is not None:
fileobj = open(file_name, "a", encoding="utf-8") fileobj = open(file_name, 'a', encoding='utf-8')
else: else:
# $set_env.py: COVERAGE_DEBUG_FILE - Where to write debug output # $set_env.py: COVERAGE_DEBUG_FILE - Where to write debug output
file_name = os.getenv("COVERAGE_DEBUG_FILE", FORCED_DEBUG_FILE) file_name = os.getenv('COVERAGE_DEBUG_FILE', FORCED_DEBUG_FILE)
if file_name in ("stdout", "stderr"): if file_name in ('stdout', 'stderr'):
fileobj = getattr(sys, file_name) fileobj = getattr(sys, file_name)
elif file_name: elif file_name:
fileobj = open(file_name, "a", encoding="utf-8") fileobj = open(file_name, 'a', encoding='utf-8')
atexit.register(fileobj.close) atexit.register(fileobj.close)
else: else:
fileobj = sys.stderr fileobj = sys.stderr
the_one = cls(fileobj, filters) the_one = cls(fileobj, filters)
cls._set_singleton_data(the_one, interim) cls._set_singleton_data(the_one, interim)
if not(the_one.filters): if not (the_one.filters):
the_one.filters = list(filters) the_one.filters = list(filters)
return the_one return the_one
@ -467,8 +478,8 @@ class DebugOutputFile:
# a process-wide singleton. So stash it in sys.modules instead of # a process-wide singleton. So stash it in sys.modules instead of
# on a class attribute. Yes, this is aggressively gross. # on a class attribute. Yes, this is aggressively gross.
SYS_MOD_NAME = "$coverage.debug.DebugOutputFile.the_one" SYS_MOD_NAME = '$coverage.debug.DebugOutputFile.the_one'
SINGLETON_ATTR = "the_one_and_is_interim" SINGLETON_ATTR = 'the_one_and_is_interim'
@classmethod @classmethod
def _set_singleton_data(cls, the_one: DebugOutputFile, interim: bool) -> None: def _set_singleton_data(cls, the_one: DebugOutputFile, interim: bool) -> None:
@ -504,7 +515,7 @@ class DebugOutputFile:
def log(msg: str, stack: bool = False) -> None: # pragma: debugging def log(msg: str, stack: bool = False) -> None: # pragma: debugging
"""Write a log message as forcefully as possible.""" """Write a log message as forcefully as possible."""
out = DebugOutputFile.get_one(interim=True) out = DebugOutputFile.get_one(interim=True)
out.write(msg+"\n") out.write(msg + '\n')
if stack: if stack:
dump_stack_frames(out=out, skip=1) dump_stack_frames(out=out, skip=1)
@ -519,8 +530,8 @@ def decorate_methods(
for name, meth in inspect.getmembers(cls, inspect.isroutine): for name, meth in inspect.getmembers(cls, inspect.isroutine):
if name not in cls.__dict__: if name not in cls.__dict__:
continue continue
if name != "__init__": if name != '__init__':
if not private and name.startswith("_"): if not private and name.startswith('_'):
continue continue
if name in butnot: if name in butnot:
continue continue
@ -542,7 +553,8 @@ def break_in_pudb(func: AnyCallable) -> AnyCallable: # pragma: debugging
OBJ_IDS = itertools.count() OBJ_IDS = itertools.count()
CALLS = itertools.count() CALLS = itertools.count()
OBJ_ID_ATTR = "$coverage.object_id" OBJ_ID_ATTR = '$coverage.object_id'
def show_calls( def show_calls(
show_args: bool = True, show_args: bool = True,
@ -555,27 +567,27 @@ def show_calls(
def _wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: def _wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
oid = getattr(self, OBJ_ID_ATTR, None) oid = getattr(self, OBJ_ID_ATTR, None)
if oid is None: if oid is None:
oid = f"{os.getpid():08d} {next(OBJ_IDS):04d}" oid = f'{os.getpid():08d} {next(OBJ_IDS):04d}'
setattr(self, OBJ_ID_ATTR, oid) setattr(self, OBJ_ID_ATTR, oid)
extra = "" extra = ''
if show_args: if show_args:
eargs = ", ".join(map(repr, args)) eargs = ', '.join(map(repr, args))
ekwargs = ", ".join("{}={!r}".format(*item) for item in kwargs.items()) ekwargs = ', '.join('{}={!r}'.format(*item) for item in kwargs.items())
extra += "(" extra += '('
extra += eargs extra += eargs
if eargs and ekwargs: if eargs and ekwargs:
extra += ", " extra += ', '
extra += ekwargs extra += ekwargs
extra += ")" extra += ')'
if show_stack: if show_stack:
extra += " @ " extra += ' @ '
extra += "; ".join(short_stack(short_filenames=True).splitlines()) extra += '; '.join(short_stack(short_filenames=True).splitlines())
callid = next(CALLS) callid = next(CALLS)
msg = f"{oid} {callid:04d} {func.__name__}{extra}\n" msg = f'{oid} {callid:04d} {func.__name__}{extra}\n'
DebugOutputFile.get_one(interim=True).write(msg) DebugOutputFile.get_one(interim=True).write(msg)
ret = func(self, *args, **kwargs) ret = func(self, *args, **kwargs)
if show_return: if show_return:
msg = f"{oid} {callid:04d} {func.__name__} return {ret!r}\n" msg = f'{oid} {callid:04d} {func.__name__} return {ret!r}\n'
DebugOutputFile.get_one(interim=True).write(msg) DebugOutputFile.get_one(interim=True).write(msg)
return ret return ret
return _wrapper return _wrapper
@ -595,9 +607,9 @@ def relevant_environment_display(env: Mapping[str, str]) -> list[tuple[str, str]
A list of pairs (name, value) to show. A list of pairs (name, value) to show.
""" """
slugs = {"COV", "PY"} slugs = {'COV', 'PY'}
include = {"HOME", "TEMP", "TMP"} include = {'HOME', 'TEMP', 'TMP'}
cloak = {"API", "TOKEN", "KEY", "SECRET", "PASS", "SIGNATURE"} cloak = {'API', 'TOKEN', 'KEY', 'SECRET', 'PASS', 'SIGNATURE'}
to_show = [] to_show = []
for name, val in env.items(): for name, val in env.items():
@ -608,6 +620,6 @@ def relevant_environment_display(env: Mapping[str, str]) -> list[tuple[str, str]
keep = True keep = True
if keep: if keep:
if any(slug in name for slug in cloak): if any(slug in name for slug in cloak):
val = re.sub(r"\w", "*", val) val = re.sub(r'\w', '*', val)
to_show.append((name, val)) to_show.append((name, val))
return human_sorted_items(to_show) return human_sorted_items(to_show)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Simple value objects for tracking what to do with files.""" """Simple value objects for tracking what to do with files."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -25,7 +23,7 @@ class FileDisposition:
has_dynamic_filename: bool has_dynamic_filename: bool
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<FileDisposition {self.canonical_filename!r}: trace={self.trace}>" return f'<FileDisposition {self.canonical_filename!r}: trace={self.trace}>'
# FileDisposition "methods": FileDisposition is a pure value object, so it can # FileDisposition "methods": FileDisposition is a pure value object, so it can
@ -39,7 +37,7 @@ def disposition_init(cls: type[TFileDisposition], original_filename: str) -> TFi
disp.canonical_filename = original_filename disp.canonical_filename = original_filename
disp.source_filename = None disp.source_filename = None
disp.trace = False disp.trace = False
disp.reason = "" disp.reason = ''
disp.file_tracer = None disp.file_tracer = None
disp.has_dynamic_filename = False disp.has_dynamic_filename = False
return disp return disp
@ -48,11 +46,11 @@ def disposition_init(cls: type[TFileDisposition], original_filename: str) -> TFi
def disposition_debug_msg(disp: TFileDisposition) -> str: def disposition_debug_msg(disp: TFileDisposition) -> str:
"""Make a nice debug message of what the FileDisposition is doing.""" """Make a nice debug message of what the FileDisposition is doing."""
if disp.trace: if disp.trace:
msg = f"Tracing {disp.original_filename!r}" msg = f'Tracing {disp.original_filename!r}'
if disp.original_filename != disp.source_filename: if disp.original_filename != disp.source_filename:
msg += f" as {disp.source_filename!r}" msg += f' as {disp.source_filename!r}'
if disp.file_tracer: if disp.file_tracer:
msg += f": will be traced by {disp.file_tracer!r}" msg += f': will be traced by {disp.file_tracer!r}'
else: else:
msg = f"Not tracing {disp.original_filename!r}: {disp.reason}" msg = f'Not tracing {disp.original_filename!r}: {disp.reason}'
return msg return msg

View file

@ -1,48 +1,48 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Determine facts about the environment.""" """Determine facts about the environment."""
from __future__ import annotations from __future__ import annotations
import os import os
import platform import platform
import sys import sys
from typing import Any
from typing import Any, Iterable from typing import Iterable
# debug_info() at the bottom wants to show all the globals, but not imports. # debug_info() at the bottom wants to show all the globals, but not imports.
# Grab the global names here to know which names to not show. Nothing defined # Grab the global names here to know which names to not show. Nothing defined
# above this line will be in the output. # above this line will be in the output.
_UNINTERESTING_GLOBALS = list(globals()) _UNINTERESTING_GLOBALS = list(globals())
# These names also shouldn't be shown. # These names also shouldn't be shown.
_UNINTERESTING_GLOBALS += ["PYBEHAVIOR", "debug_info"] _UNINTERESTING_GLOBALS += ['PYBEHAVIOR', 'debug_info']
# Operating systems. # Operating systems.
WINDOWS = sys.platform == "win32" WINDOWS = sys.platform == 'win32'
LINUX = sys.platform.startswith("linux") LINUX = sys.platform.startswith('linux')
OSX = sys.platform == "darwin" OSX = sys.platform == 'darwin'
# Python implementations. # Python implementations.
CPYTHON = (platform.python_implementation() == "CPython") CPYTHON = (platform.python_implementation() == 'CPython')
PYPY = (platform.python_implementation() == "PyPy") PYPY = (platform.python_implementation() == 'PyPy')
# Python versions. We amend version_info with one more value, a zero if an # Python versions. We amend version_info with one more value, a zero if an
# official version, or 1 if built from source beyond an official version. # official version, or 1 if built from source beyond an official version.
# Only use sys.version_info directly where tools like mypy need it to understand # Only use sys.version_info directly where tools like mypy need it to understand
# version-specfic code, otherwise use PYVERSION. # version-specfic code, otherwise use PYVERSION.
PYVERSION = sys.version_info + (int(platform.python_version()[-1] == "+"),) PYVERSION = sys.version_info + (int(platform.python_version()[-1] == '+'),)
if PYPY: if PYPY:
PYPYVERSION = sys.pypy_version_info # type: ignore[attr-defined] PYPYVERSION = sys.pypy_version_info # type: ignore[attr-defined]
# Python behavior. # Python behavior.
class PYBEHAVIOR: class PYBEHAVIOR:
"""Flags indicating this Python's behavior.""" """Flags indicating this Python's behavior."""
# Does Python conform to PEP626, Precise line numbers for debugging and other tools. # Does Python conform to PEP626, Precise line numbers for debugging and other tools.
# https://www.python.org/dev/peps/pep-0626 # https://www.python.org/dev/peps/pep-0626
pep626 = (PYVERSION > (3, 10, 0, "alpha", 4)) pep626 = (PYVERSION > (3, 10, 0, 'alpha', 4))
# Is "if __debug__" optimized away? # Is "if __debug__" optimized away?
optimize_if_debug = not pep626 optimize_if_debug = not pep626
@ -69,19 +69,19 @@ class PYBEHAVIOR:
# CPython 3.11 now jumps to the decorator line again while executing # CPython 3.11 now jumps to the decorator line again while executing
# the decorator. # the decorator.
trace_decorator_line_again = (CPYTHON and PYVERSION > (3, 11, 0, "alpha", 3, 0)) trace_decorator_line_again = (CPYTHON and PYVERSION > (3, 11, 0, 'alpha', 3, 0))
# CPython 3.9a1 made sys.argv[0] and other reported files absolute paths. # CPython 3.9a1 made sys.argv[0] and other reported files absolute paths.
report_absolute_files = ( report_absolute_files = (
(CPYTHON or (PYPY and PYPYVERSION >= (7, 3, 10))) (CPYTHON or (PYPY and PYPYVERSION >= (7, 3, 10))) and
and PYVERSION >= (3, 9) PYVERSION >= (3, 9)
) )
# Lines after break/continue/return/raise are no longer compiled into the # Lines after break/continue/return/raise are no longer compiled into the
# bytecode. They used to be marked as missing, now they aren't executable. # bytecode. They used to be marked as missing, now they aren't executable.
omit_after_jump = ( omit_after_jump = (
pep626 pep626 or
or (PYPY and PYVERSION >= (3, 9) and PYPYVERSION >= (7, 3, 12)) (PYPY and PYVERSION >= (3, 9) and PYPYVERSION >= (7, 3, 12))
) )
# PyPy has always omitted statements after return. # PyPy has always omitted statements after return.
@ -98,7 +98,7 @@ class PYBEHAVIOR:
keep_constant_test = pep626 keep_constant_test = pep626
# When leaving a with-block, do we visit the with-line again for the exit? # When leaving a with-block, do we visit the with-line again for the exit?
exit_through_with = (PYVERSION >= (3, 10, 0, "beta")) exit_through_with = (PYVERSION >= (3, 10, 0, 'beta'))
# Match-case construct. # Match-case construct.
match_case = (PYVERSION >= (3, 10)) match_case = (PYVERSION >= (3, 10))
@ -108,14 +108,14 @@ class PYBEHAVIOR:
# Modules start with a line numbered zero. This means empty modules have # Modules start with a line numbered zero. This means empty modules have
# only a 0-number line, which is ignored, giving a truly empty module. # only a 0-number line, which is ignored, giving a truly empty module.
empty_is_empty = (PYVERSION >= (3, 11, 0, "beta", 4)) empty_is_empty = (PYVERSION >= (3, 11, 0, 'beta', 4))
# Are comprehensions inlined (new) or compiled as called functions (old)? # Are comprehensions inlined (new) or compiled as called functions (old)?
# Changed in https://github.com/python/cpython/pull/101441 # Changed in https://github.com/python/cpython/pull/101441
comprehensions_are_functions = (PYVERSION <= (3, 12, 0, "alpha", 7, 0)) comprehensions_are_functions = (PYVERSION <= (3, 12, 0, 'alpha', 7, 0))
# PEP669 Low Impact Monitoring: https://peps.python.org/pep-0669/ # PEP669 Low Impact Monitoring: https://peps.python.org/pep-0669/
pep669 = bool(getattr(sys, "monitoring", None)) pep669 = bool(getattr(sys, 'monitoring', None))
# Where does frame.f_lasti point when yielding from a generator? # Where does frame.f_lasti point when yielding from a generator?
# It used to point at the YIELD, now it points at the RESUME. # It used to point at the YIELD, now it points at the RESUME.
@ -126,22 +126,22 @@ class PYBEHAVIOR:
# Coverage.py specifics, about testing scenarios. See tests/testenv.py also. # Coverage.py specifics, about testing scenarios. See tests/testenv.py also.
# Are we coverage-measuring ourselves? # Are we coverage-measuring ourselves?
METACOV = os.getenv("COVERAGE_COVERAGE") is not None METACOV = os.getenv('COVERAGE_COVERAGE') is not None
# Are we running our test suite? # Are we running our test suite?
# Even when running tests, you can use COVERAGE_TESTING=0 to disable the # Even when running tests, you can use COVERAGE_TESTING=0 to disable the
# test-specific behavior like AST checking. # test-specific behavior like AST checking.
TESTING = os.getenv("COVERAGE_TESTING") == "True" TESTING = os.getenv('COVERAGE_TESTING') == 'True'
def debug_info() -> Iterable[tuple[str, Any]]: def debug_info() -> Iterable[tuple[str, Any]]:
"""Return a list of (name, value) pairs for printing debug information.""" """Return a list of (name, value) pairs for printing debug information."""
info = [ info = [
(name, value) for name, value in globals().items() (name, value) for name, value in globals().items()
if not name.startswith("_") and name not in _UNINTERESTING_GLOBALS if not name.startswith('_') and name not in _UNINTERESTING_GLOBALS
] ]
info += [ info += [
(name, value) for name, value in PYBEHAVIOR.__dict__.items() (name, value) for name, value in PYBEHAVIOR.__dict__.items()
if not name.startswith("_") if not name.startswith('_')
] ]
return sorted(info) return sorted(info)

View file

@ -1,10 +1,9 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Exceptions coverage.py can raise.""" """Exceptions coverage.py can raise."""
from __future__ import annotations from __future__ import annotations
class _BaseCoverageException(Exception): class _BaseCoverageException(Exception):
"""The base-base of all Coverage exceptions.""" """The base-base of all Coverage exceptions."""
pass pass
@ -24,6 +23,7 @@ class DataError(CoverageException):
"""An error in using a data file.""" """An error in using a data file."""
pass pass
class NoDataError(CoverageException): class NoDataError(CoverageException):
"""We didn't have data to work with.""" """We didn't have data to work with."""
pass pass

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Execute files of Python code.""" """Execute files of Python code."""
from __future__ import annotations from __future__ import annotations
import importlib.machinery import importlib.machinery
@ -12,14 +10,18 @@ import marshal
import os import os
import struct import struct
import sys import sys
from importlib.machinery import ModuleSpec from importlib.machinery import ModuleSpec
from types import CodeType, ModuleType from types import CodeType
from types import ModuleType
from typing import Any from typing import Any
from coverage import env from coverage import env
from coverage.exceptions import CoverageException, _ExceptionDuringRun, NoCode, NoSource from coverage.exceptions import _ExceptionDuringRun
from coverage.files import canonical_filename, python_reported_file from coverage.exceptions import CoverageException
from coverage.exceptions import NoCode
from coverage.exceptions import NoSource
from coverage.files import canonical_filename
from coverage.files import python_reported_file
from coverage.misc import isolate_module from coverage.misc import isolate_module
from coverage.python import get_python_source from coverage.python import get_python_source
@ -28,11 +30,13 @@ os = isolate_module(os)
PYC_MAGIC_NUMBER = importlib.util.MAGIC_NUMBER PYC_MAGIC_NUMBER = importlib.util.MAGIC_NUMBER
class DummyLoader: class DummyLoader:
"""A shim for the pep302 __loader__, emulating pkgutil.ImpLoader. """A shim for the pep302 __loader__, emulating pkgutil.ImpLoader.
Currently only implements the .fullname attribute Currently only implements the .fullname attribute
""" """
def __init__(self, fullname: str, *_args: Any) -> None: def __init__(self, fullname: str, *_args: Any) -> None:
self.fullname = fullname self.fullname = fullname
@ -50,20 +54,20 @@ def find_module(
except ImportError as err: except ImportError as err:
raise NoSource(str(err)) from err raise NoSource(str(err)) from err
if not spec: if not spec:
raise NoSource(f"No module named {modulename!r}") raise NoSource(f'No module named {modulename!r}')
pathname = spec.origin pathname = spec.origin
packagename = spec.name packagename = spec.name
if spec.submodule_search_locations: if spec.submodule_search_locations:
mod_main = modulename + ".__main__" mod_main = modulename + '.__main__'
spec = importlib.util.find_spec(mod_main) spec = importlib.util.find_spec(mod_main)
if not spec: if not spec:
raise NoSource( raise NoSource(
f"No module named {mod_main}; " + f'No module named {mod_main}; ' +
f"{modulename!r} is a package and cannot be directly executed", f'{modulename!r} is a package and cannot be directly executed',
) )
pathname = spec.origin pathname = spec.origin
packagename = spec.name packagename = spec.name
packagename = packagename.rpartition(".")[0] packagename = packagename.rpartition('.')[0]
return pathname, packagename, spec return pathname, packagename, spec
@ -73,6 +77,7 @@ class PyRunner:
This is meant to emulate real Python execution as closely as possible. This is meant to emulate real Python execution as closely as possible.
""" """
def __init__(self, args: list[str], as_module: bool = False) -> None: def __init__(self, args: list[str], as_module: bool = False) -> None:
self.args = args self.args = args
self.as_module = as_module self.as_module = as_module
@ -142,8 +147,8 @@ class PyRunner:
elif os.path.isdir(self.arg0): elif os.path.isdir(self.arg0):
# Running a directory means running the __main__.py file in that # Running a directory means running the __main__.py file in that
# directory. # directory.
for ext in [".py", ".pyc", ".pyo"]: for ext in ['.py', '.pyc', '.pyo']:
try_filename = os.path.join(self.arg0, "__main__" + ext) try_filename = os.path.join(self.arg0, '__main__' + ext)
# 3.8.10 changed how files are reported when running a # 3.8.10 changed how files are reported when running a
# directory. But I'm not sure how far this change is going to # directory. But I'm not sure how far this change is going to
# spread, so I'll just hard-code it here for now. # spread, so I'll just hard-code it here for now.
@ -157,12 +162,12 @@ class PyRunner:
# Make a spec. I don't know if this is the right way to do it. # Make a spec. I don't know if this is the right way to do it.
try_filename = python_reported_file(try_filename) try_filename = python_reported_file(try_filename)
self.spec = importlib.machinery.ModuleSpec("__main__", None, origin=try_filename) self.spec = importlib.machinery.ModuleSpec('__main__', None, origin=try_filename)
self.spec.has_location = True self.spec.has_location = True
self.package = "" self.package = ''
self.loader = DummyLoader("__main__") self.loader = DummyLoader('__main__')
else: else:
self.loader = DummyLoader("__main__") self.loader = DummyLoader('__main__')
self.arg0 = python_reported_file(self.arg0) self.arg0 = python_reported_file(self.arg0)
@ -172,9 +177,9 @@ class PyRunner:
self._prepare2() self._prepare2()
# Create a module to serve as __main__ # Create a module to serve as __main__
main_mod = ModuleType("__main__") main_mod = ModuleType('__main__')
from_pyc = self.arg0.endswith((".pyc", ".pyo")) from_pyc = self.arg0.endswith(('.pyc', '.pyo'))
main_mod.__file__ = self.arg0 main_mod.__file__ = self.arg0
if from_pyc: if from_pyc:
main_mod.__file__ = main_mod.__file__[:-1] main_mod.__file__ = main_mod.__file__[:-1]
@ -184,9 +189,9 @@ class PyRunner:
if self.spec is not None: if self.spec is not None:
main_mod.__spec__ = self.spec main_mod.__spec__ = self.spec
main_mod.__builtins__ = sys.modules["builtins"] # type: ignore[attr-defined] main_mod.__builtins__ = sys.modules['builtins'] # type: ignore[attr-defined]
sys.modules["__main__"] = main_mod sys.modules['__main__'] = main_mod
# Set sys.argv properly. # Set sys.argv properly.
sys.argv = self.args sys.argv = self.args
@ -228,7 +233,7 @@ class PyRunner:
# is non-None when the exception is reported at the upper layer, # is non-None when the exception is reported at the upper layer,
# and a nested exception is shown to the user. This getattr fixes # and a nested exception is shown to the user. This getattr fixes
# it somehow? https://bitbucket.org/pypy/pypy/issue/1903 # it somehow? https://bitbucket.org/pypy/pypy/issue/1903
getattr(err, "__context__", None) getattr(err, '__context__', None)
# Call the excepthook. # Call the excepthook.
try: try:
@ -240,7 +245,7 @@ class PyRunner:
except Exception as exc: except Exception as exc:
# Getting the output right in the case of excepthook # Getting the output right in the case of excepthook
# shenanigans is kind of involved. # shenanigans is kind of involved.
sys.stderr.write("Error in sys.excepthook:\n") sys.stderr.write('Error in sys.excepthook:\n')
typ2, err2, tb2 = sys.exc_info() typ2, err2, tb2 = sys.exc_info()
assert typ2 is not None assert typ2 is not None
assert err2 is not None assert err2 is not None
@ -249,7 +254,7 @@ class PyRunner:
assert err2.__traceback__ is not None assert err2.__traceback__ is not None
err2.__traceback__ = err2.__traceback__.tb_next err2.__traceback__ = err2.__traceback__.tb_next
sys.__excepthook__(typ2, err2, tb2.tb_next) sys.__excepthook__(typ2, err2, tb2.tb_next)
sys.stderr.write("\nOriginal exception was:\n") sys.stderr.write('\nOriginal exception was:\n')
raise _ExceptionDuringRun(typ, err, tb.tb_next) from exc raise _ExceptionDuringRun(typ, err, tb.tb_next) from exc
else: else:
sys.exit(1) sys.exit(1)
@ -294,13 +299,13 @@ def make_code_from_py(filename: str) -> CodeType:
except (OSError, NoSource) as exc: except (OSError, NoSource) as exc:
raise NoSource(f"No file to run: '{filename}'") from exc raise NoSource(f"No file to run: '{filename}'") from exc
return compile(source, filename, "exec", dont_inherit=True) return compile(source, filename, 'exec', dont_inherit=True)
def make_code_from_pyc(filename: str) -> CodeType: def make_code_from_pyc(filename: str) -> CodeType:
"""Get a code object from a .pyc file.""" """Get a code object from a .pyc file."""
try: try:
fpyc = open(filename, "rb") fpyc = open(filename, 'rb')
except OSError as exc: except OSError as exc:
raise NoCode(f"No file to run: '{filename}'") from exc raise NoCode(f"No file to run: '{filename}'") from exc
@ -309,9 +314,9 @@ def make_code_from_pyc(filename: str) -> CodeType:
# match or we won't run the file. # match or we won't run the file.
magic = fpyc.read(4) magic = fpyc.read(4)
if magic != PYC_MAGIC_NUMBER: if magic != PYC_MAGIC_NUMBER:
raise NoCode(f"Bad magic number in .pyc file: {magic!r} != {PYC_MAGIC_NUMBER!r}") raise NoCode(f'Bad magic number in .pyc file: {magic!r} != {PYC_MAGIC_NUMBER!r}')
flags = struct.unpack("<L", fpyc.read(4))[0] flags = struct.unpack('<L', fpyc.read(4))[0]
hash_based = flags & 0x01 hash_based = flags & 0x01
if hash_based: if hash_based:
fpyc.read(8) # Skip the hash. fpyc.read(8) # Skip the hash.

View file

@ -1,31 +1,31 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""File wrangling.""" """File wrangling."""
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import ntpath import ntpath
import os
import os.path import os.path
import posixpath import posixpath
import re import re
import sys import sys
from typing import Callable
from typing import Callable, Iterable from typing import Iterable
from coverage import env from coverage import env
from coverage.exceptions import ConfigError from coverage.exceptions import ConfigError
from coverage.misc import human_sorted, isolate_module, join_regex from coverage.misc import human_sorted
from coverage.misc import isolate_module
from coverage.misc import join_regex
os = isolate_module(os) os = isolate_module(os)
RELATIVE_DIR: str = "" RELATIVE_DIR: str = ''
CANONICAL_FILENAME_CACHE: dict[str, str] = {} CANONICAL_FILENAME_CACHE: dict[str, str] = {}
def set_relative_directory() -> None: def set_relative_directory() -> None:
"""Set the directory that `relative_filename` will be relative to.""" """Set the directory that `relative_filename` will be relative to."""
global RELATIVE_DIR, CANONICAL_FILENAME_CACHE global RELATIVE_DIR, CANONICAL_FILENAME_CACHE
@ -73,7 +73,7 @@ def canonical_filename(filename: str) -> str:
if not os.path.isabs(filename): if not os.path.isabs(filename):
for path in [os.curdir] + sys.path: for path in [os.curdir] + sys.path:
if path is None: if path is None:
continue # type: ignore[unreachable] continue # type: ignore[unreachable]
f = os.path.join(path, filename) f = os.path.join(path, filename)
try: try:
exists = os.path.exists(f) exists = os.path.exists(f)
@ -89,6 +89,7 @@ def canonical_filename(filename: str) -> str:
MAX_FLAT = 100 MAX_FLAT = 100
def flat_rootname(filename: str) -> str: def flat_rootname(filename: str) -> str:
"""A base for a flat file name to correspond to this file. """A base for a flat file name to correspond to this file.
@ -101,11 +102,11 @@ def flat_rootname(filename: str) -> str:
""" """
dirname, basename = ntpath.split(filename) dirname, basename = ntpath.split(filename)
if dirname: if dirname:
fp = hashlib.new("sha3_256", dirname.encode("UTF-8")).hexdigest()[:16] fp = hashlib.new('sha3_256', dirname.encode('UTF-8')).hexdigest()[:16]
prefix = f"d_{fp}_" prefix = f'd_{fp}_'
else: else:
prefix = "" prefix = ''
return prefix + basename.replace(".", "_") return prefix + basename.replace('.', '_')
if env.WINDOWS: if env.WINDOWS:
@ -163,7 +164,7 @@ def zip_location(filename: str) -> tuple[str, str] | None:
name is in the zipfile. name is in the zipfile.
""" """
for ext in [".zip", ".whl", ".egg", ".pex"]: for ext in ['.zip', '.whl', '.egg', '.pex']:
zipbase, extension, inner = filename.partition(ext + sep(filename)) zipbase, extension, inner = filename.partition(ext + sep(filename))
if extension: if extension:
zipfile = zipbase + ext zipfile = zipbase + ext
@ -210,7 +211,7 @@ def prep_patterns(patterns: Iterable[str]) -> list[str]:
prepped = [] prepped = []
for p in patterns or []: for p in patterns or []:
prepped.append(p) prepped.append(p)
if not p.startswith(("*", "?")): if not p.startswith(('*', '?')):
prepped.append(abs_file(p)) prepped.append(abs_file(p))
return prepped return prepped
@ -223,14 +224,15 @@ class TreeMatcher:
somewhere in a subtree rooted at one of the directories. somewhere in a subtree rooted at one of the directories.
""" """
def __init__(self, paths: Iterable[str], name: str = "unknown") -> None:
def __init__(self, paths: Iterable[str], name: str = 'unknown') -> None:
self.original_paths: list[str] = human_sorted(paths) self.original_paths: list[str] = human_sorted(paths)
#self.paths = list(map(os.path.normcase, paths)) #self.paths = list(map(os.path.normcase, paths))
self.paths = [os.path.normcase(p) for p in paths] self.paths = [os.path.normcase(p) for p in paths]
self.name = name self.name = name
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<TreeMatcher {self.name} {self.original_paths!r}>" return f'<TreeMatcher {self.name} {self.original_paths!r}>'
def info(self) -> list[str]: def info(self) -> list[str]:
"""A list of strings for displaying when dumping state.""" """A list of strings for displaying when dumping state."""
@ -252,12 +254,13 @@ class TreeMatcher:
class ModuleMatcher: class ModuleMatcher:
"""A matcher for modules in a tree.""" """A matcher for modules in a tree."""
def __init__(self, module_names: Iterable[str], name:str = "unknown") -> None:
def __init__(self, module_names: Iterable[str], name: str = 'unknown') -> None:
self.modules = list(module_names) self.modules = list(module_names)
self.name = name self.name = name
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<ModuleMatcher {self.name} {self.modules!r}>" return f'<ModuleMatcher {self.name} {self.modules!r}>'
def info(self) -> list[str]: def info(self) -> list[str]:
"""A list of strings for displaying when dumping state.""" """A list of strings for displaying when dumping state."""
@ -272,7 +275,7 @@ class ModuleMatcher:
if module_name.startswith(m): if module_name.startswith(m):
if module_name == m: if module_name == m:
return True return True
if module_name[len(m)] == ".": if module_name[len(m)] == '.':
# This is a module in the package # This is a module in the package
return True return True
@ -281,13 +284,14 @@ class ModuleMatcher:
class GlobMatcher: class GlobMatcher:
"""A matcher for files by file name pattern.""" """A matcher for files by file name pattern."""
def __init__(self, pats: Iterable[str], name: str = "unknown") -> None:
def __init__(self, pats: Iterable[str], name: str = 'unknown') -> None:
self.pats = list(pats) self.pats = list(pats)
self.re = globs_to_regex(self.pats, case_insensitive=env.WINDOWS) self.re = globs_to_regex(self.pats, case_insensitive=env.WINDOWS)
self.name = name self.name = name
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<GlobMatcher {self.name} {self.pats!r}>" return f'<GlobMatcher {self.name} {self.pats!r}>'
def info(self) -> list[str]: def info(self) -> list[str]:
"""A list of strings for displaying when dumping state.""" """A list of strings for displaying when dumping state."""
@ -300,7 +304,7 @@ class GlobMatcher:
def sep(s: str) -> str: def sep(s: str) -> str:
"""Find the path separator used in this string, or os.sep if none.""" """Find the path separator used in this string, or os.sep if none."""
if sep_match := re.search(r"[\\/]", s): if sep_match := re.search(r'[\\/]', s):
the_sep = sep_match[0] the_sep = sep_match[0]
else: else:
the_sep = os.sep the_sep = os.sep
@ -309,29 +313,32 @@ def sep(s: str) -> str:
# Tokenizer for _glob_to_regex. # Tokenizer for _glob_to_regex.
# None as a sub means disallowed. # None as a sub means disallowed.
G2RX_TOKENS = [(re.compile(rx), sub) for rx, sub in [ G2RX_TOKENS = [
(r"\*\*\*+", None), # Can't have *** (re.compile(rx), sub) for rx, sub in [
(r"[^/]+\*\*+", None), # Can't have x** (r'\*\*\*+', None), # Can't have ***
(r"\*\*+[^/]+", None), # Can't have **x (r'[^/]+\*\*+', None), # Can't have x**
(r"\*\*/\*\*", None), # Can't have **/** (r'\*\*+[^/]+', None), # Can't have **x
(r"^\*+/", r"(.*[/\\\\])?"), # ^*/ matches any prefix-slash, or nothing. (r'\*\*/\*\*', None), # Can't have **/**
(r"/\*+$", r"[/\\\\].*"), # /*$ matches any slash-suffix. (r'^\*+/', r'(.*[/\\\\])?'), # ^*/ matches any prefix-slash, or nothing.
(r"\*\*/", r"(.*[/\\\\])?"), # **/ matches any subdirs, including none (r'/\*+$', r'[/\\\\].*'), # /*$ matches any slash-suffix.
(r"/", r"[/\\\\]"), # / matches either slash or backslash (r'\*\*/', r'(.*[/\\\\])?'), # **/ matches any subdirs, including none
(r"\*", r"[^/\\\\]*"), # * matches any number of non slash-likes (r'/', r'[/\\\\]'), # / matches either slash or backslash
(r"\?", r"[^/\\\\]"), # ? matches one non slash-like (r'\*', r'[^/\\\\]*'), # * matches any number of non slash-likes
(r"\[.*?\]", r"\g<0>"), # [a-f] matches [a-f] (r'\?', r'[^/\\\\]'), # ? matches one non slash-like
(r"[a-zA-Z0-9_-]+", r"\g<0>"), # word chars match themselves (r'\[.*?\]', r'\g<0>'), # [a-f] matches [a-f]
(r"[\[\]]", None), # Can't have single square brackets (r'[a-zA-Z0-9_-]+', r'\g<0>'), # word chars match themselves
(r".", r"\\\g<0>"), # Anything else is escaped to be safe (r'[\[\]]', None), # Can't have single square brackets
]] (r'.', r'\\\g<0>'), # Anything else is escaped to be safe
]
]
def _glob_to_regex(pattern: str) -> str: def _glob_to_regex(pattern: str) -> str:
"""Convert a file-path glob pattern into a regex.""" """Convert a file-path glob pattern into a regex."""
# Turn all backslashes into slashes to simplify the tokenizer. # Turn all backslashes into slashes to simplify the tokenizer.
pattern = pattern.replace("\\", "/") pattern = pattern.replace('\\', '/')
if "/" not in pattern: if '/' not in pattern:
pattern = "**/" + pattern pattern = '**/' + pattern
path_rx = [] path_rx = []
pos = 0 pos = 0
while pos < len(pattern): while pos < len(pattern):
@ -342,7 +349,7 @@ def _glob_to_regex(pattern: str) -> str:
path_rx.append(m.expand(sub)) path_rx.append(m.expand(sub))
pos = m.end() pos = m.end()
break break
return "".join(path_rx) return ''.join(path_rx)
def globs_to_regex( def globs_to_regex(
@ -371,7 +378,7 @@ def globs_to_regex(
flags |= re.IGNORECASE flags |= re.IGNORECASE
rx = join_regex(map(_glob_to_regex, patterns)) rx = join_regex(map(_glob_to_regex, patterns))
if not partial: if not partial:
rx = fr"(?:{rx})\Z" rx = fr'(?:{rx})\Z'
compiled = re.compile(rx, flags=flags) compiled = re.compile(rx, flags=flags)
return compiled return compiled
@ -387,6 +394,7 @@ class PathAliases:
map a path through those aliases to produce a unified path. map a path through those aliases to produce a unified path.
""" """
def __init__( def __init__(
self, self,
debugfn: Callable[[str], None] | None = None, debugfn: Callable[[str], None] | None = None,
@ -400,9 +408,9 @@ class PathAliases:
def pprint(self) -> None: def pprint(self) -> None:
"""Dump the important parts of the PathAliases, for debugging.""" """Dump the important parts of the PathAliases, for debugging."""
self.debugfn(f"Aliases (relative={self.relative}):") self.debugfn(f'Aliases (relative={self.relative}):')
for original_pattern, regex, result in self.aliases: for original_pattern, regex, result in self.aliases:
self.debugfn(f" Rule: {original_pattern!r} -> {result!r} using regex {regex.pattern!r}") self.debugfn(f' Rule: {original_pattern!r} -> {result!r} using regex {regex.pattern!r}')
def add(self, pattern: str, result: str) -> None: def add(self, pattern: str, result: str) -> None:
"""Add the `pattern`/`result` pair to the list of aliases. """Add the `pattern`/`result` pair to the list of aliases.
@ -421,16 +429,16 @@ class PathAliases:
pattern_sep = sep(pattern) pattern_sep = sep(pattern)
if len(pattern) > 1: if len(pattern) > 1:
pattern = pattern.rstrip(r"\/") pattern = pattern.rstrip(r'\/')
# The pattern can't end with a wildcard component. # The pattern can't end with a wildcard component.
if pattern.endswith("*"): if pattern.endswith('*'):
raise ConfigError("Pattern must not end with wildcards.") raise ConfigError('Pattern must not end with wildcards.')
# The pattern is meant to match a file path. Let's make it absolute # The pattern is meant to match a file path. Let's make it absolute
# unless it already is, or is meant to match any prefix. # unless it already is, or is meant to match any prefix.
if not self.relative: if not self.relative:
if not pattern.startswith("*") and not isabs_anywhere(pattern + pattern_sep): if not pattern.startswith('*') and not isabs_anywhere(pattern + pattern_sep):
pattern = abs_file(pattern) pattern = abs_file(pattern)
if not pattern.endswith(pattern_sep): if not pattern.endswith(pattern_sep):
pattern += pattern_sep pattern += pattern_sep
@ -440,10 +448,10 @@ class PathAliases:
# Normalize the result: it must end with a path separator. # Normalize the result: it must end with a path separator.
result_sep = sep(result) result_sep = sep(result)
result = result.rstrip(r"\/") + result_sep result = result.rstrip(r'\/') + result_sep
self.aliases.append((original_pattern, regex, result)) self.aliases.append((original_pattern, regex, result))
def map(self, path: str, exists:Callable[[str], bool] = source_exists) -> str: def map(self, path: str, exists: Callable[[str], bool] = source_exists) -> str:
"""Map `path` through the aliases. """Map `path` through the aliases.
`path` is checked against all of the patterns. The first pattern to `path` is checked against all of the patterns. The first pattern to
@ -472,18 +480,18 @@ class PathAliases:
new = new.replace(sep(path), sep(result)) new = new.replace(sep(path), sep(result))
if not self.relative: if not self.relative:
new = canonical_filename(new) new = canonical_filename(new)
dot_start = result.startswith(("./", ".\\")) and len(result) > 2 dot_start = result.startswith(('./', '.\\')) and len(result) > 2
if new.startswith(("./", ".\\")) and not dot_start: if new.startswith(('./', '.\\')) and not dot_start:
new = new[2:] new = new[2:]
if not exists(new): if not exists(new):
self.debugfn( self.debugfn(
f"Rule {original_pattern!r} changed {path!r} to {new!r} " + f'Rule {original_pattern!r} changed {path!r} to {new!r} ' +
"which doesn't exist, continuing", "which doesn't exist, continuing",
) )
continue continue
self.debugfn( self.debugfn(
f"Matched path {path!r} to rule {original_pattern!r} -> {result!r}, " + f'Matched path {path!r} to rule {original_pattern!r} -> {result!r}, ' +
f"producing {new!r}", f'producing {new!r}',
) )
return new return new
@ -494,21 +502,21 @@ class PathAliases:
if self.relative and not isabs_anywhere(path): if self.relative and not isabs_anywhere(path):
# Auto-generate a pattern to implicitly match relative files # Auto-generate a pattern to implicitly match relative files
parts = re.split(r"[/\\]", path) parts = re.split(r'[/\\]', path)
if len(parts) > 1: if len(parts) > 1:
dir1 = parts[0] dir1 = parts[0]
pattern = f"*/{dir1}" pattern = f'*/{dir1}'
regex_pat = fr"^(.*[\\/])?{re.escape(dir1)}[\\/]" regex_pat = fr'^(.*[\\/])?{re.escape(dir1)}[\\/]'
result = f"{dir1}{os.sep}" result = f'{dir1}{os.sep}'
# Only add a new pattern if we don't already have this pattern. # Only add a new pattern if we don't already have this pattern.
if not any(p == pattern for p, _, _ in self.aliases): if not any(p == pattern for p, _, _ in self.aliases):
self.debugfn( self.debugfn(
f"Generating rule: {pattern!r} -> {result!r} using regex {regex_pat!r}", f'Generating rule: {pattern!r} -> {result!r} using regex {regex_pat!r}',
) )
self.aliases.append((pattern, re.compile(regex_pat), result)) self.aliases.append((pattern, re.compile(regex_pat), result))
return self.map(path, exists=exists) return self.map(path, exists=exists)
self.debugfn(f"No rules match, path {path!r} is unchanged") self.debugfn(f'No rules match, path {path!r} is unchanged')
return path return path
@ -530,7 +538,7 @@ def find_python_files(dirname: str, include_namespace_packages: bool) -> Iterabl
""" """
for i, (dirpath, dirnames, filenames) in enumerate(os.walk(dirname)): for i, (dirpath, dirnames, filenames) in enumerate(os.walk(dirname)):
if not include_namespace_packages: if not include_namespace_packages:
if i > 0 and "__init__.py" not in filenames: if i > 0 and '__init__.py' not in filenames:
# If a directory doesn't have __init__.py, then it isn't # If a directory doesn't have __init__.py, then it isn't
# importable and neither are its files # importable and neither are its files
del dirnames[:] del dirnames[:]
@ -539,7 +547,7 @@ def find_python_files(dirname: str, include_namespace_packages: bool) -> Iterabl
# We're only interested in files that look like reasonable Python # We're only interested in files that look like reasonable Python
# files: Must end with .py or .pyw, and must not have certain funny # files: Must end with .py or .pyw, and must not have certain funny
# characters that probably mean they are editor junk. # characters that probably mean they are editor junk.
if re.match(r"^[^.#~!$@%^&*()+=,]+\.pyw?$", filename): if re.match(r'^[^.#~!$@%^&*()+=,]+\.pyw?$', filename):
yield os.path.join(dirpath, filename) yield os.path.join(dirpath, filename)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""HTML reporting for coverage.py.""" """HTML reporting for coverage.py."""
from __future__ import annotations from __future__ import annotations
import collections import collections
@ -13,20 +11,31 @@ import os
import re import re
import shutil import shutil
import string import string
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Iterable, TYPE_CHECKING, cast from typing import Any
from typing import cast
from typing import Iterable
from typing import TYPE_CHECKING
import coverage import coverage
from coverage.data import CoverageData, add_data_to_hash from coverage.data import add_data_to_hash
from coverage.data import CoverageData
from coverage.exceptions import NoDataError from coverage.exceptions import NoDataError
from coverage.files import flat_rootname from coverage.files import flat_rootname
from coverage.misc import ensure_dir, file_be_gone, Hasher, isolate_module, format_local_datetime from coverage.misc import ensure_dir
from coverage.misc import human_sorted, plural, stdout_link from coverage.misc import file_be_gone
from coverage.misc import format_local_datetime
from coverage.misc import Hasher
from coverage.misc import human_sorted
from coverage.misc import isolate_module
from coverage.misc import plural
from coverage.misc import stdout_link
from coverage.report_core import get_analysis_to_report from coverage.report_core import get_analysis_to_report
from coverage.results import Analysis, Numbers from coverage.results import Analysis
from coverage.results import Numbers
from coverage.templite import Templite from coverage.templite import Templite
from coverage.types import TLineNo, TMorf from coverage.types import TLineNo
from coverage.types import TMorf
from coverage.version import __url__ from coverage.version import __url__
@ -56,7 +65,7 @@ os = isolate_module(os)
def data_filename(fname: str) -> str: def data_filename(fname: str) -> str:
"""Return the path to an "htmlfiles" data file of ours. """Return the path to an "htmlfiles" data file of ours.
""" """
static_dir = os.path.join(os.path.dirname(__file__), "htmlfiles") static_dir = os.path.join(os.path.dirname(__file__), 'htmlfiles')
static_filename = os.path.join(static_dir, fname) static_filename = os.path.join(static_dir, fname)
return static_filename return static_filename
@ -69,9 +78,9 @@ def read_data(fname: str) -> str:
def write_html(fname: str, html: str) -> None: def write_html(fname: str, html: str) -> None:
"""Write `html` to `fname`, properly encoded.""" """Write `html` to `fname`, properly encoded."""
html = re.sub(r"(\A\s+)|(\s+$)", "", html, flags=re.MULTILINE) + "\n" html = re.sub(r'(\A\s+)|(\s+$)', '', html, flags=re.MULTILINE) + '\n'
with open(fname, "wb") as fout: with open(fname, 'wb') as fout:
fout.write(html.encode("ascii", "xmlcharrefreplace")) fout.write(html.encode('ascii', 'xmlcharrefreplace'))
@dataclass @dataclass
@ -86,11 +95,11 @@ class LineData:
context_list: list[str] context_list: list[str]
short_annotations: list[str] short_annotations: list[str]
long_annotations: list[str] long_annotations: list[str]
html: str = "" html: str = ''
context_str: str | None = None context_str: str | None = None
annotate: str | None = None annotate: str | None = None
annotate_long: str | None = None annotate_long: str | None = None
css_class: str = "" css_class: str = ''
@dataclass @dataclass
@ -104,7 +113,7 @@ class FileData:
class HtmlDataGeneration: class HtmlDataGeneration:
"""Generate structured data to be turned into HTML reports.""" """Generate structured data to be turned into HTML reports."""
EMPTY = "(empty)" EMPTY = '(empty)'
def __init__(self, cov: Coverage) -> None: def __init__(self, cov: Coverage) -> None:
self.coverage = cov self.coverage = cov
@ -112,8 +121,8 @@ class HtmlDataGeneration:
data = self.coverage.get_data() data = self.coverage.get_data()
self.has_arcs = data.has_arcs() self.has_arcs = data.has_arcs()
if self.config.show_contexts: if self.config.show_contexts:
if data.measured_contexts() == {""}: if data.measured_contexts() == {''}:
self.coverage._warn("No contexts were measured") self.coverage._warn('No contexts were measured')
data.set_query_contexts(self.config.report_contexts) data.set_query_contexts(self.config.report_contexts)
def data_for_file(self, fr: FileReporter, analysis: Analysis) -> FileData: def data_for_file(self, fr: FileReporter, analysis: Analysis) -> FileData:
@ -129,47 +138,49 @@ class HtmlDataGeneration:
for lineno, tokens in enumerate(fr.source_token_lines(), start=1): for lineno, tokens in enumerate(fr.source_token_lines(), start=1):
# Figure out how to mark this line. # Figure out how to mark this line.
category = "" category = ''
short_annotations = [] short_annotations = []
long_annotations = [] long_annotations = []
if lineno in analysis.excluded: if lineno in analysis.excluded:
category = "exc" category = 'exc'
elif lineno in analysis.missing: elif lineno in analysis.missing:
category = "mis" category = 'mis'
elif self.has_arcs and lineno in missing_branch_arcs: elif self.has_arcs and lineno in missing_branch_arcs:
category = "par" category = 'par'
for b in missing_branch_arcs[lineno]: for b in missing_branch_arcs[lineno]:
if b < 0: if b < 0:
short_annotations.append("exit") short_annotations.append('exit')
else: else:
short_annotations.append(str(b)) short_annotations.append(str(b))
long_annotations.append(fr.missing_arc_description(lineno, b, arcs_executed)) long_annotations.append(fr.missing_arc_description(lineno, b, arcs_executed))
elif lineno in analysis.statements: elif lineno in analysis.statements:
category = "run" category = 'run'
contexts = [] contexts = []
contexts_label = "" contexts_label = ''
context_list = [] context_list = []
if category and self.config.show_contexts: if category and self.config.show_contexts:
contexts = human_sorted(c or self.EMPTY for c in contexts_by_lineno.get(lineno, ())) contexts = human_sorted(c or self.EMPTY for c in contexts_by_lineno.get(lineno, ()))
if contexts == [self.EMPTY]: if contexts == [self.EMPTY]:
contexts_label = self.EMPTY contexts_label = self.EMPTY
else: else:
contexts_label = f"{len(contexts)} ctx" contexts_label = f'{len(contexts)} ctx'
context_list = contexts context_list = contexts
lines.append(LineData( lines.append(
tokens=tokens, LineData(
number=lineno, tokens=tokens,
category=category, number=lineno,
statement=(lineno in analysis.statements), category=category,
contexts=contexts, statement=(lineno in analysis.statements),
contexts_label=contexts_label, contexts=contexts,
context_list=context_list, contexts_label=contexts_label,
short_annotations=short_annotations, context_list=context_list,
long_annotations=long_annotations, short_annotations=short_annotations,
)) long_annotations=long_annotations,
),
)
file_data = FileData( file_data = FileData(
relative_filename=fr.relative_filename(), relative_filename=fr.relative_filename(),
@ -182,15 +193,17 @@ class HtmlDataGeneration:
class FileToReport: class FileToReport:
"""A file we're considering reporting.""" """A file we're considering reporting."""
def __init__(self, fr: FileReporter, analysis: Analysis) -> None: def __init__(self, fr: FileReporter, analysis: Analysis) -> None:
self.fr = fr self.fr = fr
self.analysis = analysis self.analysis = analysis
self.rootname = flat_rootname(fr.relative_filename()) self.rootname = flat_rootname(fr.relative_filename())
self.html_filename = self.rootname + ".html" self.html_filename = self.rootname + '.html'
HTML_SAFE = string.ascii_letters + string.digits + "!#$%'()*+,-./:;=?@[]^_`{|}~" HTML_SAFE = string.ascii_letters + string.digits + "!#$%'()*+,-./:;=?@[]^_`{|}~"
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def encode_int(n: int) -> str: def encode_int(n: int) -> str:
"""Create a short HTML-safe string from an integer, using HTML_SAFE.""" """Create a short HTML-safe string from an integer, using HTML_SAFE."""
@ -201,7 +214,7 @@ def encode_int(n: int) -> str:
while n: while n:
n, t = divmod(n, len(HTML_SAFE)) n, t = divmod(n, len(HTML_SAFE))
r.append(HTML_SAFE[t]) r.append(HTML_SAFE[t])
return "".join(r) return ''.join(r)
class HtmlReporter: class HtmlReporter:
@ -210,11 +223,11 @@ class HtmlReporter:
# These files will be copied from the htmlfiles directory to the output # These files will be copied from the htmlfiles directory to the output
# directory. # directory.
STATIC_FILES = [ STATIC_FILES = [
"style.css", 'style.css',
"coverage_html.js", 'coverage_html.js',
"keybd_closed.png", 'keybd_closed.png',
"keybd_open.png", 'keybd_open.png',
"favicon_32.png", 'favicon_32.png',
] ]
def __init__(self, cov: Coverage) -> None: def __init__(self, cov: Coverage) -> None:
@ -253,29 +266,29 @@ class HtmlReporter:
self.template_globals = { self.template_globals = {
# Functions available in the templates. # Functions available in the templates.
"escape": escape, 'escape': escape,
"pair": pair, 'pair': pair,
"len": len, 'len': len,
# Constants for this report. # Constants for this report.
"__url__": __url__, '__url__': __url__,
"__version__": coverage.__version__, '__version__': coverage.__version__,
"title": title, 'title': title,
"time_stamp": format_local_datetime(datetime.datetime.now()), 'time_stamp': format_local_datetime(datetime.datetime.now()),
"extra_css": self.extra_css, 'extra_css': self.extra_css,
"has_arcs": self.has_arcs, 'has_arcs': self.has_arcs,
"show_contexts": self.config.show_contexts, 'show_contexts': self.config.show_contexts,
# Constants for all reports. # Constants for all reports.
# These css classes determine which lines are highlighted by default. # These css classes determine which lines are highlighted by default.
"category": { 'category': {
"exc": "exc show_exc", 'exc': 'exc show_exc',
"mis": "mis show_mis", 'mis': 'mis show_mis',
"par": "par run show_par", 'par': 'par run show_par',
"run": "run", 'run': 'run',
}, },
} }
self.pyfile_html_source = read_data("pyfile.html") self.pyfile_html_source = read_data('pyfile.html')
self.source_tmpl = Templite(self.pyfile_html_source, self.template_globals) self.source_tmpl = Templite(self.pyfile_html_source, self.template_globals)
def report(self, morfs: Iterable[TMorf] | None) -> float: def report(self, morfs: Iterable[TMorf] | None) -> float:
@ -303,17 +316,17 @@ class HtmlReporter:
for i, ftr in enumerate(files_to_report): for i, ftr in enumerate(files_to_report):
if i == 0: if i == 0:
prev_html = "index.html" prev_html = 'index.html'
else: else:
prev_html = files_to_report[i - 1].html_filename prev_html = files_to_report[i - 1].html_filename
if i == len(files_to_report) - 1: if i == len(files_to_report) - 1:
next_html = "index.html" next_html = 'index.html'
else: else:
next_html = files_to_report[i + 1].html_filename next_html = files_to_report[i + 1].html_filename
self.write_html_file(ftr, prev_html, next_html) self.write_html_file(ftr, prev_html, next_html)
if not self.all_files_nums: if not self.all_files_nums:
raise NoDataError("No data to report.") raise NoDataError('No data to report.')
self.totals = cast(Numbers, sum(self.all_files_nums)) self.totals = cast(Numbers, sum(self.all_files_nums))
@ -322,7 +335,7 @@ class HtmlReporter:
first_html = files_to_report[0].html_filename first_html = files_to_report[0].html_filename
final_html = files_to_report[-1].html_filename final_html = files_to_report[-1].html_filename
else: else:
first_html = final_html = "index.html" first_html = final_html = 'index.html'
self.index_file(first_html, final_html) self.index_file(first_html, final_html)
self.make_local_static_report_files() self.make_local_static_report_files()
@ -344,8 +357,8 @@ class HtmlReporter:
# .gitignore can't be copied from the source tree because it would # .gitignore can't be copied from the source tree because it would
# prevent the static files from being checked in. # prevent the static files from being checked in.
if self.directory_was_empty: if self.directory_was_empty:
with open(os.path.join(self.directory, ".gitignore"), "w") as fgi: with open(os.path.join(self.directory, '.gitignore'), 'w') as fgi:
fgi.write("# Created by coverage.py\n*\n") fgi.write('# Created by coverage.py\n*\n')
# The user may have extra CSS they want copied. # The user may have extra CSS they want copied.
if self.extra_css: if self.extra_css:
@ -401,29 +414,29 @@ class HtmlReporter:
# Build the HTML for the line. # Build the HTML for the line.
html_parts = [] html_parts = []
for tok_type, tok_text in ldata.tokens: for tok_type, tok_text in ldata.tokens:
if tok_type == "ws": if tok_type == 'ws':
html_parts.append(escape(tok_text)) html_parts.append(escape(tok_text))
else: else:
tok_html = escape(tok_text) or "&nbsp;" tok_html = escape(tok_text) or '&nbsp;'
html_parts.append(f'<span class="{tok_type}">{tok_html}</span>') html_parts.append(f'<span class="{tok_type}">{tok_html}</span>')
ldata.html = "".join(html_parts) ldata.html = ''.join(html_parts)
if ldata.context_list: if ldata.context_list:
encoded_contexts = [ encoded_contexts = [
encode_int(context_codes[c_context]) for c_context in ldata.context_list encode_int(context_codes[c_context]) for c_context in ldata.context_list
] ]
code_width = max(len(ec) for ec in encoded_contexts) code_width = max(len(ec) for ec in encoded_contexts)
ldata.context_str = ( ldata.context_str = (
str(code_width) str(code_width) +
+ "".join(ec.ljust(code_width) for ec in encoded_contexts) ''.join(ec.ljust(code_width) for ec in encoded_contexts)
) )
else: else:
ldata.context_str = "" ldata.context_str = ''
if ldata.short_annotations: if ldata.short_annotations:
# 202F is NARROW NO-BREAK SPACE. # 202F is NARROW NO-BREAK SPACE.
# 219B is RIGHTWARDS ARROW WITH STROKE. # 219B is RIGHTWARDS ARROW WITH STROKE.
ldata.annotate = ",&nbsp;&nbsp; ".join( ldata.annotate = ',&nbsp;&nbsp; '.join(
f"{ldata.number}&#x202F;&#x219B;&#x202F;{d}" f'{ldata.number}&#x202F;&#x219B;&#x202F;{d}'
for d in ldata.short_annotations for d in ldata.short_annotations
) )
else: else:
@ -434,10 +447,10 @@ class HtmlReporter:
if len(longs) == 1: if len(longs) == 1:
ldata.annotate_long = longs[0] ldata.annotate_long = longs[0]
else: else:
ldata.annotate_long = "{:d} missed branches: {}".format( ldata.annotate_long = '{:d} missed branches: {}'.format(
len(longs), len(longs),
", ".join( ', '.join(
f"{num:d}) {ann_long}" f'{num:d}) {ann_long}'
for num, ann_long in enumerate(longs, start=1) for num, ann_long in enumerate(longs, start=1)
), ),
) )
@ -447,24 +460,24 @@ class HtmlReporter:
css_classes = [] css_classes = []
if ldata.category: if ldata.category:
css_classes.append( css_classes.append(
self.template_globals["category"][ldata.category], # type: ignore[index] self.template_globals['category'][ldata.category], # type: ignore[index]
) )
ldata.css_class = " ".join(css_classes) or "pln" ldata.css_class = ' '.join(css_classes) or 'pln'
html_path = os.path.join(self.directory, ftr.html_filename) html_path = os.path.join(self.directory, ftr.html_filename)
html = self.source_tmpl.render({ html = self.source_tmpl.render({
**file_data.__dict__, **file_data.__dict__,
"contexts_json": contexts_json, 'contexts_json': contexts_json,
"prev_html": prev_html, 'prev_html': prev_html,
"next_html": next_html, 'next_html': next_html,
}) })
write_html(html_path, html) write_html(html_path, html)
# Save this file's information for the index file. # Save this file's information for the index file.
index_info: IndexInfoDict = { index_info: IndexInfoDict = {
"nums": ftr.analysis.numbers, 'nums': ftr.analysis.numbers,
"html_filename": ftr.html_filename, 'html_filename': ftr.html_filename,
"relative_filename": ftr.fr.relative_filename(), 'relative_filename': ftr.fr.relative_filename(),
} }
self.file_summaries.append(index_info) self.file_summaries.append(index_info)
self.incr.set_index_info(ftr.rootname, index_info) self.incr.set_index_info(ftr.rootname, index_info)
@ -472,30 +485,30 @@ class HtmlReporter:
def index_file(self, first_html: str, final_html: str) -> None: def index_file(self, first_html: str, final_html: str) -> None:
"""Write the index.html file for this report.""" """Write the index.html file for this report."""
self.make_directory() self.make_directory()
index_tmpl = Templite(read_data("index.html"), self.template_globals) index_tmpl = Templite(read_data('index.html'), self.template_globals)
skipped_covered_msg = skipped_empty_msg = "" skipped_covered_msg = skipped_empty_msg = ''
if self.skipped_covered_count: if self.skipped_covered_count:
n = self.skipped_covered_count n = self.skipped_covered_count
skipped_covered_msg = f"{n} file{plural(n)} skipped due to complete coverage." skipped_covered_msg = f'{n} file{plural(n)} skipped due to complete coverage.'
if self.skipped_empty_count: if self.skipped_empty_count:
n = self.skipped_empty_count n = self.skipped_empty_count
skipped_empty_msg = f"{n} empty file{plural(n)} skipped." skipped_empty_msg = f'{n} empty file{plural(n)} skipped.'
html = index_tmpl.render({ html = index_tmpl.render({
"files": self.file_summaries, 'files': self.file_summaries,
"totals": self.totals, 'totals': self.totals,
"skipped_covered_msg": skipped_covered_msg, 'skipped_covered_msg': skipped_covered_msg,
"skipped_empty_msg": skipped_empty_msg, 'skipped_empty_msg': skipped_empty_msg,
"first_html": first_html, 'first_html': first_html,
"final_html": final_html, 'final_html': final_html,
}) })
index_file = os.path.join(self.directory, "index.html") index_file = os.path.join(self.directory, 'index.html')
write_html(index_file, html) write_html(index_file, html)
print_href = stdout_link(index_file, f"file://{os.path.abspath(index_file)}") print_href = stdout_link(index_file, f'file://{os.path.abspath(index_file)}')
self.coverage._message(f"Wrote HTML report to {print_href}") self.coverage._message(f'Wrote HTML report to {print_href}')
# Write the latest hashes for next time. # Write the latest hashes for next time.
self.incr.write() self.incr.write()
@ -504,12 +517,12 @@ class HtmlReporter:
class IncrementalChecker: class IncrementalChecker:
"""Logic and data to support incremental reporting.""" """Logic and data to support incremental reporting."""
STATUS_FILE = "status.json" STATUS_FILE = 'status.json'
STATUS_FORMAT = 2 STATUS_FORMAT = 2
NOTE = ( NOTE = (
"This file is an internal implementation detail to speed up HTML report" 'This file is an internal implementation detail to speed up HTML report' +
+ " generation. Its format can change at any time. You might be looking" ' generation. Its format can change at any time. You might be looking' +
+ " for the JSON report: https://coverage.rtfd.io/cmd.html#cmd-json" ' for the JSON report: https://coverage.rtfd.io/cmd.html#cmd-json'
) )
# The data looks like: # The data looks like:
@ -545,7 +558,7 @@ class IncrementalChecker:
def reset(self) -> None: def reset(self) -> None:
"""Initialize to empty. Causes all files to be reported.""" """Initialize to empty. Causes all files to be reported."""
self.globals = "" self.globals = ''
self.files: dict[str, FileInfoDict] = {} self.files: dict[str, FileInfoDict] = {}
def read(self) -> None: def read(self) -> None:
@ -559,17 +572,17 @@ class IncrementalChecker:
usable = False usable = False
else: else:
usable = True usable = True
if status["format"] != self.STATUS_FORMAT: if status['format'] != self.STATUS_FORMAT:
usable = False usable = False
elif status["version"] != coverage.__version__: elif status['version'] != coverage.__version__:
usable = False usable = False
if usable: if usable:
self.files = {} self.files = {}
for filename, fileinfo in status["files"].items(): for filename, fileinfo in status['files'].items():
fileinfo["index"]["nums"] = Numbers(*fileinfo["index"]["nums"]) fileinfo['index']['nums'] = Numbers(*fileinfo['index']['nums'])
self.files[filename] = fileinfo self.files[filename] = fileinfo
self.globals = status["globals"] self.globals = status['globals']
else: else:
self.reset() self.reset()
@ -578,19 +591,19 @@ class IncrementalChecker:
status_file = os.path.join(self.directory, self.STATUS_FILE) status_file = os.path.join(self.directory, self.STATUS_FILE)
files = {} files = {}
for filename, fileinfo in self.files.items(): for filename, fileinfo in self.files.items():
index = fileinfo["index"] index = fileinfo['index']
index["nums"] = index["nums"].init_args() # type: ignore[typeddict-item] index['nums'] = index['nums'].init_args() # type: ignore[typeddict-item]
files[filename] = fileinfo files[filename] = fileinfo
status = { status = {
"note": self.NOTE, 'note': self.NOTE,
"format": self.STATUS_FORMAT, 'format': self.STATUS_FORMAT,
"version": coverage.__version__, 'version': coverage.__version__,
"globals": self.globals, 'globals': self.globals,
"files": files, 'files': files,
} }
with open(status_file, "w") as fout: with open(status_file, 'w') as fout:
json.dump(status, fout, separators=(",", ":")) json.dump(status, fout, separators=(',', ':'))
def check_global_data(self, *data: Any) -> None: def check_global_data(self, *data: Any) -> None:
"""Check the global data that can affect incremental reporting.""" """Check the global data that can affect incremental reporting."""
@ -609,7 +622,7 @@ class IncrementalChecker:
`rootname` is the name being used for the file. `rootname` is the name being used for the file.
""" """
m = Hasher() m = Hasher()
m.update(fr.source().encode("utf-8")) m.update(fr.source().encode('utf-8'))
add_data_to_hash(data, fr.filename, m) add_data_to_hash(data, fr.filename, m)
this_hash = m.hexdigest() this_hash = m.hexdigest()
@ -624,19 +637,19 @@ class IncrementalChecker:
def file_hash(self, fname: str) -> str: def file_hash(self, fname: str) -> str:
"""Get the hash of `fname`'s contents.""" """Get the hash of `fname`'s contents."""
return self.files.get(fname, {}).get("hash", "") # type: ignore[call-overload] return self.files.get(fname, {}).get('hash', '') # type: ignore[call-overload]
def set_file_hash(self, fname: str, val: str) -> None: def set_file_hash(self, fname: str, val: str) -> None:
"""Set the hash of `fname`'s contents.""" """Set the hash of `fname`'s contents."""
self.files.setdefault(fname, {})["hash"] = val # type: ignore[typeddict-item] self.files.setdefault(fname, {})['hash'] = val # type: ignore[typeddict-item]
def index_info(self, fname: str) -> IndexInfoDict: def index_info(self, fname: str) -> IndexInfoDict:
"""Get the information for index.html for `fname`.""" """Get the information for index.html for `fname`."""
return self.files.get(fname, {}).get("index", {}) # type: ignore return self.files.get(fname, {}).get('index', {}) # type: ignore
def set_index_info(self, fname: str, info: IndexInfoDict) -> None: def set_index_info(self, fname: str, info: IndexInfoDict) -> None:
"""Set the information for index.html for `fname`.""" """Set the information for index.html for `fname`."""
self.files.setdefault(fname, {})["index"] = info # type: ignore[typeddict-item] self.files.setdefault(fname, {})['index'] = info # type: ignore[typeddict-item]
# Helpers for templates and generating HTML # Helpers for templates and generating HTML
@ -648,9 +661,9 @@ def escape(t: str) -> str:
""" """
# Convert HTML special chars into HTML entities. # Convert HTML special chars into HTML entities.
return t.replace("&", "&amp;").replace("<", "&lt;") return t.replace('&', '&amp;').replace('<', '&lt;')
def pair(ratio: tuple[int, int]) -> str: def pair(ratio: tuple[int, int]) -> str:
"""Format a pair of numbers so JavaScript can read them in an attribute.""" """Format a pair of numbers so JavaScript can read them in an attribute."""
return "{} {}".format(*ratio) return '{} {}'.format(*ratio)

View file

@ -1,8 +1,6 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Determining whether files are being measured/reported or not.""" """Determining whether files are being measured/reported or not."""
from __future__ import annotations from __future__ import annotations
import importlib.util import importlib.util
@ -14,20 +12,31 @@ import re
import sys import sys
import sysconfig import sysconfig
import traceback import traceback
from types import FrameType
from types import FrameType, ModuleType from types import ModuleType
from typing import ( from typing import Any
cast, Any, Iterable, TYPE_CHECKING, from typing import cast
) from typing import Iterable
from typing import TYPE_CHECKING
from coverage import env from coverage import env
from coverage.disposition import FileDisposition, disposition_init from coverage.disposition import disposition_init
from coverage.exceptions import CoverageException, PluginError from coverage.disposition import FileDisposition
from coverage.files import TreeMatcher, GlobMatcher, ModuleMatcher from coverage.exceptions import CoverageException
from coverage.files import prep_patterns, find_python_files, canonical_filename from coverage.exceptions import PluginError
from coverage.files import canonical_filename
from coverage.files import find_python_files
from coverage.files import GlobMatcher
from coverage.files import ModuleMatcher
from coverage.files import prep_patterns
from coverage.files import TreeMatcher
from coverage.misc import sys_modules_saved from coverage.misc import sys_modules_saved
from coverage.python import source_for_file, source_for_morf from coverage.python import source_for_file
from coverage.types import TFileDisposition, TMorf, TWarnFn, TDebugCtl from coverage.python import source_for_morf
from coverage.types import TDebugCtl
from coverage.types import TFileDisposition
from coverage.types import TMorf
from coverage.types import TWarnFn
if TYPE_CHECKING: if TYPE_CHECKING:
from coverage.config import CoverageConfig from coverage.config import CoverageConfig
@ -65,7 +74,7 @@ def canonical_path(morf: TMorf, directory: bool = False) -> str:
""" """
morf_path = canonical_filename(source_for_morf(morf)) morf_path = canonical_filename(source_for_morf(morf))
if morf_path.endswith("__init__.py") or directory: if morf_path.endswith('__init__.py') or directory:
morf_path = os.path.split(morf_path)[0] morf_path = os.path.split(morf_path)[0]
return morf_path return morf_path
@ -83,16 +92,16 @@ def name_for_module(filename: str, frame: FrameType | None) -> str:
""" """
module_globals = frame.f_globals if frame is not None else {} module_globals = frame.f_globals if frame is not None else {}
dunder_name: str = module_globals.get("__name__", None) dunder_name: str = module_globals.get('__name__', None)
if isinstance(dunder_name, str) and dunder_name != "__main__": if isinstance(dunder_name, str) and dunder_name != '__main__':
# This is the usual case: an imported module. # This is the usual case: an imported module.
return dunder_name return dunder_name
spec = module_globals.get("__spec__", None) spec = module_globals.get('__spec__', None)
if spec: if spec:
fullname = spec.name fullname = spec.name
if isinstance(fullname, str) and fullname != "__main__": if isinstance(fullname, str) and fullname != '__main__':
# Module loaded via: runpy -m # Module loaded via: runpy -m
return fullname return fullname
@ -106,12 +115,12 @@ def name_for_module(filename: str, frame: FrameType | None) -> str:
def module_is_namespace(mod: ModuleType) -> bool: def module_is_namespace(mod: ModuleType) -> bool:
"""Is the module object `mod` a PEP420 namespace module?""" """Is the module object `mod` a PEP420 namespace module?"""
return hasattr(mod, "__path__") and getattr(mod, "__file__", None) is None return hasattr(mod, '__path__') and getattr(mod, '__file__', None) is None
def module_has_file(mod: ModuleType) -> bool: def module_has_file(mod: ModuleType) -> bool:
"""Does the module object `mod` have an existing __file__ ?""" """Does the module object `mod` have an existing __file__ ?"""
mod__file__ = getattr(mod, "__file__", None) mod__file__ = getattr(mod, '__file__', None)
if mod__file__ is None: if mod__file__ is None:
return False return False
return os.path.exists(mod__file__) return os.path.exists(mod__file__)
@ -146,7 +155,7 @@ def add_stdlib_paths(paths: set[str]) -> None:
# spread across a few locations. Look at all the candidate modules # spread across a few locations. Look at all the candidate modules
# we've imported, and take all the different ones. # we've imported, and take all the different ones.
for m in modules_we_happen_to_have: for m in modules_we_happen_to_have:
if hasattr(m, "__file__"): if hasattr(m, '__file__'):
paths.add(canonical_path(m, directory=True)) paths.add(canonical_path(m, directory=True))
@ -157,10 +166,10 @@ def add_third_party_paths(paths: set[str]) -> None:
for scheme in scheme_names: for scheme in scheme_names:
# https://foss.heptapod.net/pypy/pypy/-/issues/3433 # https://foss.heptapod.net/pypy/pypy/-/issues/3433
better_scheme = "pypy_posix" if scheme == "pypy" else scheme better_scheme = 'pypy_posix' if scheme == 'pypy' else scheme
if os.name in better_scheme.split("_"): if os.name in better_scheme.split('_'):
config_paths = sysconfig.get_paths(scheme) config_paths = sysconfig.get_paths(scheme)
for path_name in ["platlib", "purelib", "scripts"]: for path_name in ['platlib', 'purelib', 'scripts']:
paths.add(config_paths[path_name]) paths.add(config_paths[path_name])
@ -170,7 +179,7 @@ def add_coverage_paths(paths: set[str]) -> None:
paths.add(cover_path) paths.add(cover_path)
if env.TESTING: if env.TESTING:
# Don't include our own test code. # Don't include our own test code.
paths.add(os.path.join(cover_path, "tests")) paths.add(os.path.join(cover_path, 'tests'))
class InOrOut: class InOrOut:
@ -221,7 +230,7 @@ class InOrOut:
# The matchers for should_trace. # The matchers for should_trace.
# Generally useful information # Generally useful information
_debug("sys.path:" + "".join(f"\n {p}" for p in sys.path)) _debug('sys.path:' + ''.join(f'\n {p}' for p in sys.path))
# Create the matchers we need for should_trace # Create the matchers we need for should_trace
self.source_match = None self.source_match = None
@ -232,28 +241,28 @@ class InOrOut:
if self.source or self.source_pkgs: if self.source or self.source_pkgs:
against = [] against = []
if self.source: if self.source:
self.source_match = TreeMatcher(self.source, "source") self.source_match = TreeMatcher(self.source, 'source')
against.append(f"trees {self.source_match!r}") against.append(f'trees {self.source_match!r}')
if self.source_pkgs: if self.source_pkgs:
self.source_pkgs_match = ModuleMatcher(self.source_pkgs, "source_pkgs") self.source_pkgs_match = ModuleMatcher(self.source_pkgs, 'source_pkgs')
against.append(f"modules {self.source_pkgs_match!r}") against.append(f'modules {self.source_pkgs_match!r}')
_debug("Source matching against " + " and ".join(against)) _debug('Source matching against ' + ' and '.join(against))
else: else:
if self.pylib_paths: if self.pylib_paths:
self.pylib_match = TreeMatcher(self.pylib_paths, "pylib") self.pylib_match = TreeMatcher(self.pylib_paths, 'pylib')
_debug(f"Python stdlib matching: {self.pylib_match!r}") _debug(f'Python stdlib matching: {self.pylib_match!r}')
if self.include: if self.include:
self.include_match = GlobMatcher(self.include, "include") self.include_match = GlobMatcher(self.include, 'include')
_debug(f"Include matching: {self.include_match!r}") _debug(f'Include matching: {self.include_match!r}')
if self.omit: if self.omit:
self.omit_match = GlobMatcher(self.omit, "omit") self.omit_match = GlobMatcher(self.omit, 'omit')
_debug(f"Omit matching: {self.omit_match!r}") _debug(f'Omit matching: {self.omit_match!r}')
self.cover_match = TreeMatcher(self.cover_paths, "coverage") self.cover_match = TreeMatcher(self.cover_paths, 'coverage')
_debug(f"Coverage code matching: {self.cover_match!r}") _debug(f'Coverage code matching: {self.cover_match!r}')
self.third_match = TreeMatcher(self.third_paths, "third") self.third_match = TreeMatcher(self.third_paths, 'third')
_debug(f"Third-party lib matching: {self.third_match!r}") _debug(f'Third-party lib matching: {self.third_match!r}')
# Check if the source we want to measure has been installed as a # Check if the source we want to measure has been installed as a
# third-party package. # third-party package.
@ -263,30 +272,30 @@ class InOrOut:
for pkg in self.source_pkgs: for pkg in self.source_pkgs:
try: try:
modfile, path = file_and_path_for_module(pkg) modfile, path = file_and_path_for_module(pkg)
_debug(f"Imported source package {pkg!r} as {modfile!r}") _debug(f'Imported source package {pkg!r} as {modfile!r}')
except CoverageException as exc: except CoverageException as exc:
_debug(f"Couldn't import source package {pkg!r}: {exc}") _debug(f"Couldn't import source package {pkg!r}: {exc}")
continue continue
if modfile: if modfile:
if self.third_match.match(modfile): if self.third_match.match(modfile):
_debug( _debug(
f"Source in third-party: source_pkg {pkg!r} at {modfile!r}", f'Source in third-party: source_pkg {pkg!r} at {modfile!r}',
) )
self.source_in_third_paths.add(canonical_path(source_for_file(modfile))) self.source_in_third_paths.add(canonical_path(source_for_file(modfile)))
else: else:
for pathdir in path: for pathdir in path:
if self.third_match.match(pathdir): if self.third_match.match(pathdir):
_debug( _debug(
f"Source in third-party: {pkg!r} path directory at {pathdir!r}", f'Source in third-party: {pkg!r} path directory at {pathdir!r}',
) )
self.source_in_third_paths.add(pathdir) self.source_in_third_paths.add(pathdir)
for src in self.source: for src in self.source:
if self.third_match.match(src): if self.third_match.match(src):
_debug(f"Source in third-party: source directory {src!r}") _debug(f'Source in third-party: source directory {src!r}')
self.source_in_third_paths.add(src) self.source_in_third_paths.add(src)
self.source_in_third_match = TreeMatcher(self.source_in_third_paths, "source_in_third") self.source_in_third_match = TreeMatcher(self.source_in_third_paths, 'source_in_third')
_debug(f"Source in third-party matching: {self.source_in_third_match}") _debug(f'Source in third-party matching: {self.source_in_third_match}')
self.plugins: Plugins self.plugins: Plugins
self.disp_class: type[TFileDisposition] = FileDisposition self.disp_class: type[TFileDisposition] = FileDisposition
@ -309,8 +318,8 @@ class InOrOut:
disp.reason = reason disp.reason = reason
return disp return disp
if original_filename.startswith("<"): if original_filename.startswith('<'):
return nope(disp, "original file name is not real") return nope(disp, 'original file name is not real')
if frame is not None: if frame is not None:
# Compiled Python files have two file names: frame.f_code.co_filename is # Compiled Python files have two file names: frame.f_code.co_filename is
@ -319,10 +328,10 @@ class InOrOut:
# .pyc files can be moved after compilation (for example, by being # .pyc files can be moved after compilation (for example, by being
# installed), we look for __file__ in the frame and prefer it to the # installed), we look for __file__ in the frame and prefer it to the
# co_filename value. # co_filename value.
dunder_file = frame.f_globals and frame.f_globals.get("__file__") dunder_file = frame.f_globals and frame.f_globals.get('__file__')
if dunder_file: if dunder_file:
filename = source_for_file(dunder_file) filename = source_for_file(dunder_file)
if original_filename and not original_filename.startswith("<"): if original_filename and not original_filename.startswith('<'):
orig = os.path.basename(original_filename) orig = os.path.basename(original_filename)
if orig != os.path.basename(filename): if orig != os.path.basename(filename):
# Files shouldn't be renamed when moved. This happens when # Files shouldn't be renamed when moved. This happens when
@ -334,15 +343,15 @@ class InOrOut:
# Empty string is pretty useless. # Empty string is pretty useless.
return nope(disp, "empty string isn't a file name") return nope(disp, "empty string isn't a file name")
if filename.startswith("memory:"): if filename.startswith('memory:'):
return nope(disp, "memory isn't traceable") return nope(disp, "memory isn't traceable")
if filename.startswith("<"): if filename.startswith('<'):
# Lots of non-file execution is represented with artificial # Lots of non-file execution is represented with artificial
# file names like "<string>", "<doctest readme.txt[0]>", or # file names like "<string>", "<doctest readme.txt[0]>", or
# "<exec_function>". Don't ever trace these executions, since we # "<exec_function>". Don't ever trace these executions, since we
# can't do anything with the data later anyway. # can't do anything with the data later anyway.
return nope(disp, "file name is not real") return nope(disp, 'file name is not real')
canonical = canonical_filename(filename) canonical = canonical_filename(filename)
disp.canonical_filename = canonical disp.canonical_filename = canonical
@ -369,7 +378,7 @@ class InOrOut:
except Exception: except Exception:
plugin_name = plugin._coverage_plugin_name plugin_name = plugin._coverage_plugin_name
tb = traceback.format_exc() tb = traceback.format_exc()
self.warn(f"Disabling plug-in {plugin_name!r} due to an exception:\n{tb}") self.warn(f'Disabling plug-in {plugin_name!r} due to an exception:\n{tb}')
plugin._coverage_enabled = False plugin._coverage_enabled = False
continue continue
else: else:
@ -402,7 +411,7 @@ class InOrOut:
# any canned exclusions. If they didn't, then we have to exclude the # any canned exclusions. If they didn't, then we have to exclude the
# stdlib and coverage.py directories. # stdlib and coverage.py directories.
if self.source_match or self.source_pkgs_match: if self.source_match or self.source_pkgs_match:
extra = "" extra = ''
ok = False ok = False
if self.source_pkgs_match: if self.source_pkgs_match:
if self.source_pkgs_match.match(modulename): if self.source_pkgs_match.match(modulename):
@ -410,41 +419,41 @@ class InOrOut:
if modulename in self.source_pkgs_unmatched: if modulename in self.source_pkgs_unmatched:
self.source_pkgs_unmatched.remove(modulename) self.source_pkgs_unmatched.remove(modulename)
else: else:
extra = f"module {modulename!r} " extra = f'module {modulename!r} '
if not ok and self.source_match: if not ok and self.source_match:
if self.source_match.match(filename): if self.source_match.match(filename):
ok = True ok = True
if not ok: if not ok:
return extra + "falls outside the --source spec" return extra + 'falls outside the --source spec'
if self.third_match.match(filename) and not self.source_in_third_match.match(filename): if self.third_match.match(filename) and not self.source_in_third_match.match(filename):
return "inside --source, but is third-party" return 'inside --source, but is third-party'
elif self.include_match: elif self.include_match:
if not self.include_match.match(filename): if not self.include_match.match(filename):
return "falls outside the --include trees" return 'falls outside the --include trees'
else: else:
# We exclude the coverage.py code itself, since a little of it # We exclude the coverage.py code itself, since a little of it
# will be measured otherwise. # will be measured otherwise.
if self.cover_match.match(filename): if self.cover_match.match(filename):
return "is part of coverage.py" return 'is part of coverage.py'
# If we aren't supposed to trace installed code, then check if this # If we aren't supposed to trace installed code, then check if this
# is near the Python standard library and skip it if so. # is near the Python standard library and skip it if so.
if self.pylib_match and self.pylib_match.match(filename): if self.pylib_match and self.pylib_match.match(filename):
return "is in the stdlib" return 'is in the stdlib'
# Exclude anything in the third-party installation areas. # Exclude anything in the third-party installation areas.
if self.third_match.match(filename): if self.third_match.match(filename):
return "is a third-party module" return 'is a third-party module'
# Check the file against the omit pattern. # Check the file against the omit pattern.
if self.omit_match and self.omit_match.match(filename): if self.omit_match and self.omit_match.match(filename):
return "is inside an --omit pattern" return 'is inside an --omit pattern'
# No point tracing a file we can't later write to SQLite. # No point tracing a file we can't later write to SQLite.
try: try:
filename.encode("utf-8") filename.encode('utf-8')
except UnicodeEncodeError: except UnicodeEncodeError:
return "non-encodable filename" return 'non-encodable filename'
# No reason found to skip this file. # No reason found to skip this file.
return None return None
@ -453,20 +462,20 @@ class InOrOut:
"""Warn if there are settings that conflict.""" """Warn if there are settings that conflict."""
if self.include: if self.include:
if self.source or self.source_pkgs: if self.source or self.source_pkgs:
self.warn("--include is ignored because --source is set", slug="include-ignored") self.warn('--include is ignored because --source is set', slug='include-ignored')
def warn_already_imported_files(self) -> None: def warn_already_imported_files(self) -> None:
"""Warn if files have already been imported that we will be measuring.""" """Warn if files have already been imported that we will be measuring."""
if self.include or self.source or self.source_pkgs: if self.include or self.source or self.source_pkgs:
warned = set() warned = set()
for mod in list(sys.modules.values()): for mod in list(sys.modules.values()):
filename = getattr(mod, "__file__", None) filename = getattr(mod, '__file__', None)
if filename is None: if filename is None:
continue continue
if filename in warned: if filename in warned:
continue continue
if len(getattr(mod, "__path__", ())) > 1: if len(getattr(mod, '__path__', ())) > 1:
# A namespace package, which confuses this code, so ignore it. # A namespace package, which confuses this code, so ignore it.
continue continue
@ -477,10 +486,10 @@ class InOrOut:
# of tracing anyway. # of tracing anyway.
continue continue
if disp.trace: if disp.trace:
msg = f"Already imported a file that will be measured: {filename}" msg = f'Already imported a file that will be measured: {filename}'
self.warn(msg, slug="already-imported") self.warn(msg, slug='already-imported')
warned.add(filename) warned.add(filename)
elif self.debug and self.debug.should("trace"): elif self.debug and self.debug.should('trace'):
self.debug.write( self.debug.write(
"Didn't trace already imported file {!r}: {}".format( "Didn't trace already imported file {!r}: {}".format(
disp.original_filename, disp.reason, disp.original_filename, disp.reason,
@ -500,7 +509,7 @@ class InOrOut:
""" """
mod = sys.modules.get(pkg) mod = sys.modules.get(pkg)
if mod is None: if mod is None:
self.warn(f"Module {pkg} was never imported.", slug="module-not-imported") self.warn(f'Module {pkg} was never imported.', slug='module-not-imported')
return return
if module_is_namespace(mod): if module_is_namespace(mod):
@ -509,14 +518,14 @@ class InOrOut:
return return
if not module_has_file(mod): if not module_has_file(mod):
self.warn(f"Module {pkg} has no Python source.", slug="module-not-python") self.warn(f'Module {pkg} has no Python source.', slug='module-not-python')
return return
# The module was in sys.modules, and seems like a module with code, but # The module was in sys.modules, and seems like a module with code, but
# we never measured it. I guess that means it was imported before # we never measured it. I guess that means it was imported before
# coverage even started. # coverage even started.
msg = f"Module {pkg} was previously imported, but not measured" msg = f'Module {pkg} was previously imported, but not measured'
self.warn(msg, slug="module-not-measured") self.warn(msg, slug='module-not-measured')
def find_possibly_unexecuted_files(self) -> Iterable[tuple[str, str | None]]: def find_possibly_unexecuted_files(self) -> Iterable[tuple[str, str | None]]:
"""Find files in the areas of interest that might be untraced. """Find files in the areas of interest that might be untraced.
@ -524,8 +533,10 @@ class InOrOut:
Yields pairs: file path, and responsible plug-in name. Yields pairs: file path, and responsible plug-in name.
""" """
for pkg in self.source_pkgs: for pkg in self.source_pkgs:
if (pkg not in sys.modules or if (
not module_has_file(sys.modules[pkg])): pkg not in sys.modules or
not module_has_file(sys.modules[pkg])
):
continue continue
pkg_file = source_for_file(cast(str, sys.modules[pkg].__file__)) pkg_file = source_for_file(cast(str, sys.modules[pkg].__file__))
yield from self._find_executable_files(canonical_path(pkg_file)) yield from self._find_executable_files(canonical_path(pkg_file))
@ -569,16 +580,16 @@ class InOrOut:
Returns a list of (key, value) pairs. Returns a list of (key, value) pairs.
""" """
info = [ info = [
("coverage_paths", self.cover_paths), ('coverage_paths', self.cover_paths),
("stdlib_paths", self.pylib_paths), ('stdlib_paths', self.pylib_paths),
("third_party_paths", self.third_paths), ('third_party_paths', self.third_paths),
("source_in_third_party_paths", self.source_in_third_paths), ('source_in_third_party_paths', self.source_in_third_paths),
] ]
matcher_names = [ matcher_names = [
"source_match", "source_pkgs_match", 'source_match', 'source_pkgs_match',
"include_match", "omit_match", 'include_match', 'omit_match',
"cover_match", "pylib_match", "third_match", "source_in_third_match", 'cover_match', 'pylib_match', 'third_match', 'source_in_third_match',
] ]
for matcher_name in matcher_names: for matcher_name in matcher_names:
@ -586,7 +597,7 @@ class InOrOut:
if matcher: if matcher:
matcher_info = matcher.info() matcher_info = matcher.info()
else: else:
matcher_info = "-none-" matcher_info = '-none-'
info.append((matcher_name, matcher_info)) info.append((matcher_name, matcher_info))
return info return info

View file

@ -1,20 +1,22 @@
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt # For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt
"""Json reporting for coverage.py""" """Json reporting for coverage.py"""
from __future__ import annotations from __future__ import annotations
import datetime import datetime
import json import json
import sys import sys
from typing import Any
from typing import Any, IO, Iterable, TYPE_CHECKING from typing import IO
from typing import Iterable
from typing import TYPE_CHECKING
from coverage import __version__ from coverage import __version__
from coverage.report_core import get_analysis_to_report from coverage.report_core import get_analysis_to_report
from coverage.results import Analysis, Numbers from coverage.results import Analysis
from coverage.types import TMorf, TLineNo from coverage.results import Numbers
from coverage.types import TLineNo
from coverage.types import TMorf
if TYPE_CHECKING: if TYPE_CHECKING:
from coverage import Coverage from coverage import Coverage
@ -25,10 +27,11 @@ if TYPE_CHECKING:
# 2: add the meta.format field. # 2: add the meta.format field.
FORMAT_VERSION = 2 FORMAT_VERSION = 2
class JsonReporter: class JsonReporter:
"""A reporter for writing JSON coverage results.""" """A reporter for writing JSON coverage results."""
report_type = "JSON report" report_type = 'JSON report'
def __init__(self, coverage: Coverage) -> None: def __init__(self, coverage: Coverage) -> None:
self.coverage = coverage self.coverage = coverage
@ -47,12 +50,12 @@ class JsonReporter:
outfile = outfile or sys.stdout outfile = outfile or sys.stdout
coverage_data = self.coverage.get_data() coverage_data = self.coverage.get_data()
coverage_data.set_query_contexts(self.config.report_contexts) coverage_data.set_query_contexts(self.config.report_contexts)
self.report_data["meta"] = { self.report_data['meta'] = {
"format": FORMAT_VERSION, 'format': FORMAT_VERSION,
"version": __version__, 'version': __version__,
"timestamp": datetime.datetime.now().isoformat(), 'timestamp': datetime.datetime.now().isoformat(),
"branch_coverage": coverage_data.has_arcs(), 'branch_coverage': coverage_data.has_arcs(),
"show_contexts": self.config.json_show_contexts, 'show_contexts': self.config.json_show_contexts,
} }
measured_files = {} measured_files = {}
@ -62,23 +65,23 @@ class JsonReporter:
analysis, analysis,
) )
self.report_data["files"] = measured_files self.report_data['files'] = measured_files
self.report_data["totals"] = { self.report_data['totals'] = {
"covered_lines": self.total.n_executed, 'covered_lines': self.total.n_executed,
"num_statements": self.total.n_statements, 'num_statements': self.total.n_statements,
"percent_covered": self.total.pc_covered, 'percent_covered': self.total.pc_covered,
"percent_covered_display": self.total.pc_covered_str, 'percent_covered_display': self.total.pc_covered_str,
"missing_lines": self.total.n_missing, 'missing_lines': self.total.n_missing,
"excluded_lines": self.total.n_excluded, 'excluded_lines': self.total.n_excluded,
} }
if coverage_data.has_arcs(): if coverage_data.has_arcs():
self.report_data["totals"].update({ self.report_data['totals'].update({
"num_branches": self.total.n_branches, 'num_branches': self.total.n_branches,
"num_partial_branches": self.total.n_partial_branches, 'num_partial_branches': self.total.n_partial_branches,
"covered_branches": self.total.n_executed_branches, 'covered_branches': self.total.n_executed_branches,
"missing_branches": self.total.n_missing_branches, 'missing_branches': self.total.n_missing_branches,
}) })
json.dump( json.dump(
@ -94,32 +97,32 @@ class JsonReporter:
nums = analysis.numbers nums = analysis.numbers
self.total += nums self.total += nums
summary = { summary = {
"covered_lines": nums.n_executed, 'covered_lines': nums.n_executed,
"num_statements": nums.n_statements, 'num_statements': nums.n_statements,
"percent_covered": nums.pc_covered, 'percent_covered': nums.pc_covered,
"percent_covered_display": nums.pc_covered_str, 'percent_covered_display': nums.pc_covered_str,
"missing_lines": nums.n_missing, 'missing_lines': nums.n_missing,
"excluded_lines": nums.n_excluded, 'excluded_lines': nums.n_excluded,
} }
reported_file = { reported_file = {
"executed_lines": sorted(analysis.executed), 'executed_lines': sorted(analysis.executed),
"summary": summary, 'summary': summary,
"missing_lines": sorted(analysis.missing), 'missing_lines': sorted(analysis.missing),
"excluded_lines": sorted(analysis.excluded), 'excluded_lines': sorted(analysis.excluded),
} }
if self.config.json_show_contexts: if self.config.json_show_contexts:
reported_file["contexts"] = analysis.data.contexts_by_lineno(analysis.filename) reported_file['contexts'] = analysis.data.contexts_by_lineno(analysis.filename)
if coverage_data.has_arcs(): if coverage_data.has_arcs():
summary.update({ summary.update({
"num_branches": nums.n_branches, 'num_branches': nums.n_branches,
"num_partial_branches": nums.n_partial_branches, 'num_partial_branches': nums.n_partial_branches,
"covered_branches": nums.n_executed_branches, 'covered_branches': nums.n_executed_branches,
"missing_branches": nums.n_missing_branches, 'missing_branches': nums.n_missing_branches,
}) })
reported_file["executed_branches"] = list( reported_file['executed_branches'] = list(
_convert_branch_arcs(analysis.executed_branch_arcs()), _convert_branch_arcs(analysis.executed_branch_arcs()),
) )
reported_file["missing_branches"] = list( reported_file['missing_branches'] = list(
_convert_branch_arcs(analysis.missing_branch_arcs()), _convert_branch_arcs(analysis.missing_branch_arcs()),
) )
return reported_file return reported_file

Some files were not shown because too many files have changed in this diff Show more