from __future__ import annotations

from collections.abc import Callable
from collections.abc import Sequence
from functools import partial
from inspect import getmro
from inspect import isclass
from typing import cast
from typing import Generic
from typing import overload
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar

if TYPE_CHECKING:
    from typing import Self

_BaseExceptionT_co = TypeVar('_BaseExceptionT_co', bound=BaseException, covariant=True)
_BaseExceptionT = TypeVar('_BaseExceptionT', bound=BaseException)
_ExceptionT_co = TypeVar('_ExceptionT_co', bound=Exception, covariant=True)
_ExceptionT = TypeVar('_ExceptionT', bound=Exception)


def check_direct_subclass(
    exc: BaseException,
    parents: tuple[type[BaseException], ...],
) -> bool:
    """Return True if a class in ``exc``'s MRO is one of ``parents``.

    The last MRO entry (``object``) is skipped. Membership is tested with
    ``in`` against the MRO, so only real inheritance counts (no
    ``__instancecheck__`` / virtual-subclass hooks are consulted).
    """
    return any(cls in parents for cls in getmro(exc.__class__)[:-1])


def get_condition_filter(
    condition: type[_BaseExceptionT]
    | tuple[type[_BaseExceptionT], ...]
    | Callable[[_BaseExceptionT_co], bool],
) -> Callable[[_BaseExceptionT_co], bool]:
    """Normalise a ``subgroup()``/``split()`` condition into a predicate.

    :param condition: an exception class, a tuple of exception classes, or a
        callable taking an exception and returning a bool
    :return: a callable that takes an exception and returns whether it matches
    :raises TypeError: if ``condition`` is none of the accepted forms
        (including a tuple with a non-exception-class member)
    """
    if isclass(condition) and issubclass(
        cast(Type[BaseException], condition), BaseException,
    ):
        return partial(check_direct_subclass, parents=(condition,))
    elif isinstance(condition, tuple):
        # A tuple with any non-exception member falls through to the
        # TypeError below.
        if all(isclass(x) and issubclass(x, BaseException) for x in condition):
            return partial(check_direct_subclass, parents=condition)
    elif callable(condition):
        return cast('Callable[[BaseException], bool]', condition)

    raise TypeError('expected a function, exception type or tuple of exception types')


class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
    """A combination of multiple unrelated exceptions."""

    def __new__(
        cls,
        __message: str,
        __exceptions: Sequence[_BaseExceptionT_co],
    ) -> Self:
        """Validate the arguments and create the group.

        When called on ``BaseExceptionGroup`` itself with members that are
        all ``Exception`` instances, an ``ExceptionGroup`` is created instead.

        :raises TypeError: if the message is not a ``str``, the exceptions are
            not a sequence, or a non-``Exception`` member is passed to an
            ``Exception``-derived group class
        :raises ValueError: if the sequence is empty or holds a non-exception
        """
        if not isinstance(__message, str):
            raise TypeError(f'argument 1 must be str, not {type(__message)}')
        if not isinstance(__exceptions, Sequence):
            raise TypeError('second argument (exceptions) must be a sequence')
        if not __exceptions:
            raise ValueError(
                'second argument (exceptions) must be a non-empty sequence',
            )

        for i, exc in enumerate(__exceptions):
            if not isinstance(exc, BaseException):
                raise ValueError(
                    f'Item {i} of second argument (exceptions) is not an exception',
                )

        if cls is BaseExceptionGroup:
            # Narrow to ExceptionGroup when every member is a plain Exception.
            if all(isinstance(exc, Exception) for exc in __exceptions):
                cls = ExceptionGroup

        if issubclass(cls, Exception):
            for exc in __exceptions:
                if not isinstance(exc, Exception):
                    if cls is ExceptionGroup:
                        raise TypeError(
                            'Cannot nest BaseExceptions in an ExceptionGroup',
                        )
                    else:
                        raise TypeError(
                            f'Cannot nest BaseExceptions in {cls.__name__!r}',
                        )

        instance = super().__new__(cls, __message, __exceptions)
        instance._message = __message
        instance._exceptions = __exceptions
        return instance

    def add_note(self, note: str) -> None:
        """Append ``note`` to ``__notes__``, creating the list on first use.

        :raises TypeError: if ``note`` is not a ``str``
        """
        if not isinstance(note, str):
            raise TypeError(
                f'Expected a string, got note={note!r} (type {type(note).__name__})',
            )

        if not hasattr(self, '__notes__'):
            self.__notes__: list[str] = []

        self.__notes__.append(note)

    @property
    def message(self) -> str:
        """The message passed at construction."""
        return self._message

    @property
    def exceptions(
        self,
    ) -> tuple[_BaseExceptionT_co | BaseExceptionGroup[_BaseExceptionT_co], ...]:
        """The contained exceptions, as a new tuple on every access."""
        return tuple(self._exceptions)

    def _derive_and_copy_attributes(
        self,
        excs: Sequence[_BaseExceptionT],
    ) -> BaseExceptionGroup[_BaseExceptionT]:
        """Build a group via ``derive()`` and carry this group's metadata over.

        Copies ``__cause__``, ``__context__`` and ``__traceback__``. Notes are
        filled in here only when the derived group has none, so that a
        subclass overriding ``derive()`` without handling ``__notes__`` still
        yields groups that keep them.
        """
        group = self.derive(excs)
        group.__cause__ = self.__cause__
        group.__context__ = self.__context__
        group.__traceback__ = self.__traceback__
        if hasattr(self, '__notes__') and not hasattr(group, '__notes__'):
            # Create a new list so that add_note() only affects one exceptiongroup
            group.__notes__ = list(self.__notes__)
        return group

    @overload
    def subgroup(
        self,
        __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...],
    ) -> ExceptionGroup[_ExceptionT] | None: ...

    @overload
    def subgroup(
        self,
        __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...],
    ) -> BaseExceptionGroup[_BaseExceptionT] | None: ...

    @overload
    def subgroup(
        self,
        __condition: Callable[[_BaseExceptionT_co | Self], bool],
    ) -> BaseExceptionGroup[_BaseExceptionT_co] | None: ...

    def subgroup(
        self,
        __condition: type[_BaseExceptionT]
        | tuple[type[_BaseExceptionT], ...]
        | Callable[[_BaseExceptionT_co | Self], bool],
    ) -> BaseExceptionGroup[_BaseExceptionT] | None:
        """Return a group holding only the exceptions matching the condition.

        Returns ``self`` when the group itself matches or when nothing was
        filtered out, ``None`` when nothing matches, and otherwise a new
        group (built through ``derive()``) that keeps the nesting structure.

        :raises TypeError: if the condition is not an exception type, a tuple
            of exception types or a callable
        """
        condition = get_condition_filter(__condition)
        modified = False
        if condition(self):
            return self

        exceptions: list[BaseException] = []
        for exc in self.exceptions:
            if isinstance(exc, BaseExceptionGroup):
                subgroup = exc.subgroup(__condition)
                if subgroup is not None:
                    exceptions.append(subgroup)

                # Any change inside a nested group means this level must be
                # rebuilt as well.
                if subgroup is not exc:
                    modified = True
            elif condition(exc):
                exceptions.append(exc)
            else:
                modified = True

        if not modified:
            return self
        elif exceptions:
            return self._derive_and_copy_attributes(exceptions)
        else:
            return None

    @overload
    def split(
        self,
        __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...],
    ) -> tuple[
        ExceptionGroup[_ExceptionT] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    @overload
    def split(
        self,
        __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...],
    ) -> tuple[
        BaseExceptionGroup[_BaseExceptionT] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    @overload
    def split(
        self,
        __condition: Callable[[_BaseExceptionT_co | Self], bool],
    ) -> tuple[
        BaseExceptionGroup[_BaseExceptionT_co] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    def split(
        self,
        __condition: type[_BaseExceptionT]
        | tuple[type[_BaseExceptionT], ...]
        | Callable[[_BaseExceptionT_co], bool],
    ) -> (
        tuple[
            ExceptionGroup[_ExceptionT] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
        | tuple[
            BaseExceptionGroup[_BaseExceptionT] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
        | tuple[
            BaseExceptionGroup[_BaseExceptionT_co] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
    ):
        """Partition the group into ``(matching, non-matching)`` groups.

        Returns ``(self, None)`` when the group itself matches. Otherwise each
        side is a new group built through ``derive()``, or ``None`` when that
        side is empty. Nested groups are split recursively.

        :raises TypeError: if the condition is not an exception type, a tuple
            of exception types or a callable
        """
        condition = get_condition_filter(__condition)
        if condition(self):
            return self, None

        matching_exceptions: list[BaseException] = []
        nonmatching_exceptions: list[BaseException] = []
        for exc in self.exceptions:
            if isinstance(exc, BaseExceptionGroup):
                matching, nonmatching = exc.split(condition)
                if matching is not None:
                    matching_exceptions.append(matching)

                if nonmatching is not None:
                    nonmatching_exceptions.append(nonmatching)
            elif condition(exc):
                matching_exceptions.append(exc)
            else:
                nonmatching_exceptions.append(exc)

        matching_group = (
            self._derive_and_copy_attributes(matching_exceptions)
            if matching_exceptions
            else None
        )
        nonmatching_group = (
            self._derive_and_copy_attributes(nonmatching_exceptions)
            if nonmatching_exceptions
            else None
        )
        return matching_group, nonmatching_group

    @overload
    def derive(self, __excs: Sequence[_ExceptionT]) -> ExceptionGroup[_ExceptionT]: ...

    @overload
    def derive(
        self,
        __excs: Sequence[_BaseExceptionT],
    ) -> BaseExceptionGroup[_BaseExceptionT]: ...

    def derive(
        self,
        __excs: Sequence[_BaseExceptionT],
    ) -> BaseExceptionGroup[_BaseExceptionT]:
        """Create a group with this group's message around ``__excs``.

        ``subgroup()`` and ``split()`` call this to build their results, so a
        subclass can override it to control the class of those results. Any
        ``__notes__`` on this group are copied onto the new one.
        """
        eg = BaseExceptionGroup(self.message, __excs)
        if hasattr(self, '__notes__'):
            # Create a new list so that add_note() only affects one exceptiongroup
            eg.__notes__ = list(self.__notes__)

        return eg

    def __str__(self) -> str:
        suffix = '' if len(self._exceptions) == 1 else 's'
        return f'{self.message} ({len(self._exceptions)} sub-exception{suffix})'

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.message!r}, {self._exceptions!r})'


class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
    """A ``BaseExceptionGroup`` whose members must all be ``Exception`` instances."""

    def __new__(cls, __message: str, __exceptions: Sequence[_ExceptionT_co]) -> Self:
        return super().__new__(cls, __message, __exceptions)

    if TYPE_CHECKING:
        # Narrowed signatures for type checkers only; at runtime the
        # inherited implementations are used.

        @property
        def exceptions(
            self,
        ) -> tuple[_ExceptionT_co | ExceptionGroup[_ExceptionT_co], ...]: ...

        @overload  # type: ignore[override]
        def subgroup(
            self,
            __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...],
        ) -> ExceptionGroup[_ExceptionT] | None: ...

        @overload
        def subgroup(
            self,
            __condition: Callable[[_ExceptionT_co | Self], bool],
        ) -> ExceptionGroup[_ExceptionT_co] | None: ...

        def subgroup(
            self,
            __condition: type[_ExceptionT]
            | tuple[type[_ExceptionT], ...]
            | Callable[[_ExceptionT_co], bool],
        ) -> ExceptionGroup[_ExceptionT] | None:
            return super().subgroup(__condition)

        @overload
        def split(
            self,
            __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...],
        ) -> tuple[
            ExceptionGroup[_ExceptionT] | None,
            ExceptionGroup[_ExceptionT_co] | None,
        ]: ...

        @overload
        def split(
            self,
            __condition: Callable[[_ExceptionT_co | Self], bool],
        ) -> tuple[
            ExceptionGroup[_ExceptionT_co] | None,
            ExceptionGroup[_ExceptionT_co] | None,
        ]: ...

        def split(
            self: Self,
            __condition: type[_ExceptionT]
            | tuple[type[_ExceptionT], ...]
            | Callable[[_ExceptionT_co], bool],
        ) -> tuple[
            ExceptionGroup[_ExceptionT_co] | None,
            ExceptionGroup[_ExceptionT_co] | None,
        ]:
            return super().split(__condition)