diff --git a/testtools/matchers/_basic.py b/testtools/matchers/_basic.py index d3650ef3..c326777a 100644 --- a/testtools/matchers/_basic.py +++ b/testtools/matchers/_basic.py @@ -18,7 +18,7 @@ import operator import re -from collections.abc import Callable +from collections.abc import Callable, Sized from pprint import pformat from typing import Any, Generic, TypeVar @@ -459,7 +459,7 @@ def match(self, value: str) -> Mismatch | None: return None -def has_len(x: Any, y: int) -> bool: +def has_len(x: Sized, y: int) -> bool: return len(x) == y diff --git a/testtools/monkey.py b/testtools/monkey.py index 51669e58..ce8221e4 100644 --- a/testtools/monkey.py +++ b/testtools/monkey.py @@ -3,7 +3,10 @@ """Helpers for monkey-patching Python code.""" from collections.abc import Callable -from typing import Any +from typing import ParamSpec, TypeVar + +_P = ParamSpec("_P") +_R = TypeVar("_R") __all__ = [ "MonkeyPatcher", @@ -71,7 +74,9 @@ def restore(self) -> None: else: setattr(obj, name, value) - def run_with_patches(self, f: Callable[..., Any], *args: Any, **kw: Any) -> Any: + def run_with_patches( + self, f: Callable[_P, _R], *args: _P.args, **kw: _P.kwargs + ) -> _R: """Run 'f' with the given args and kwargs with all patches applied. Restores all objects to their original state when finished. diff --git a/testtools/runtest.py b/testtools/runtest.py index 06743ccc..053ec739 100644 --- a/testtools/runtest.py +++ b/testtools/runtest.py @@ -9,7 +9,7 @@ import sys from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NoReturn from testtools.testresult import ( ExcInfo, @@ -252,7 +252,7 @@ def _got_user_exception( return self.exception_caught -def _raise_force_fail_error() -> None: +def _raise_force_fail_error() -> NoReturn: raise AssertionError("Forced Test Failure") diff --git a/testtools/testcase.py b/testtools/testcase.py index a5dc8efd..8f3fbef0 100644 --- a/testtools/testcase.py +++ b/testtools/testcase.py @@ -23,11 +23,14 @@ import types import unittest from collections.abc import Callable, Iterator -from typing import TYPE_CHECKING, NoReturn, ParamSpec, TypeVar, cast, overload +from typing import TYPE_CHECKING, Generic, NoReturn, ParamSpec, TypeVar, cast, overload from unittest.case import SkipTest T = TypeVar("T") U = TypeVar("U") +_E = TypeVar("_E", bound=BaseException) +_E2 = TypeVar("_E2", bound=BaseException) +_E3 = TypeVar("_E3", bound=BaseException) # ruff: noqa: E402 - TypeVars must be defined before importing testtools modules from testtools import content @@ -513,17 +516,38 @@ def assertRaises( @overload # type: ignore[override] def assertRaises( self, - expected_exception: type[BaseException] | tuple[type[BaseException]], + expected_exception: type[_E], callable: None = ..., - ) -> "_AssertRaisesContext": ... + ) -> "_AssertRaisesContext[_E]": ... - def assertRaises( # type: ignore[override] + @overload # type: ignore[override] + def assertRaises( self, - expected_exception: type[BaseException] | tuple[type[BaseException]], + expected_exception: tuple[type[_E], type[_E2]], + callable: None = ..., + ) -> "_AssertRaisesContext[_E | _E2]": ... + + @overload # type: ignore[override] + def assertRaises( + self, + expected_exception: tuple[type[_E], type[_E2], type[_E3]], + callable: None = ..., + ) -> "_AssertRaisesContext[_E | _E2 | _E3]": ... + + @overload # type: ignore[override] + def assertRaises( + self, + expected_exception: tuple[type[BaseException], ...], + callable: None = ..., + ) -> "_AssertRaisesContext[BaseException]": ... + + def assertRaises( # type: ignore[override, misc] + self, + expected_exception: type[BaseException] | tuple[type[BaseException], ...], callable: Callable[_P, _R] | None = None, *args: _P.args, **kwargs: _P.kwargs, - ) -> "_AssertRaisesContext | BaseException": + ) -> "_AssertRaisesContext[BaseException] | BaseException": """Fail unless an exception of class expected_exception is thrown by callable when invoked with arguments args and keyword arguments kwargs. If a different type of exception is @@ -1185,7 +1209,7 @@ def decorator(test_item: _F) -> _F: if not isinstance(test_item, class_types): @functools.wraps(test_item) - def skip_wrapper(*args: object, **kwargs: object) -> None: + def skip_wrapper(*args: object, **kwargs: object) -> NoReturn: raise TestCase.skipException(reason) test_item = cast(_F, skip_wrapper) @@ -1223,7 +1247,7 @@ def _id(obj: _F) -> _F: return _id -class _AssertRaisesContext: +class _AssertRaisesContext(Generic[_E]): """A context manager to handle expected exceptions for assertRaises. This provides compatibility with unittest's assertRaises context manager. @@ -1231,7 +1255,7 @@ class _AssertRaisesContext: def __init__( self, - expected: type[BaseException] | tuple[type[BaseException]], + expected: type[_E] | tuple[type[BaseException], ...], test_case: TestCase, msg: str | None = None, ) -> None: @@ -1244,7 +1268,7 @@ def __init__( self.expected = expected self.test_case = test_case self.msg = msg - self.exception: BaseException | None = None + self.exception: _E | None = None def __enter__(self) -> "Self": return self @@ -1274,7 +1298,7 @@ def __exit__( # let unexpected exceptions pass through return False # store exception for later retrieval - self.exception = exc_value + self.exception = cast(_E, exc_value) return True @@ -1341,7 +1365,7 @@ def __exit__( return True -class Nullary: +class Nullary(Generic[_R]): """Turn a callable into a nullary callable. The advantage of this over ``lambda: f(*args, **kwargs)`` is that it @@ -1358,7 +1382,7 @@ def __init__( self._args = args self._kwargs = kwargs - def __call__(self) -> object: + def __call__(self) -> _R: return self._callable_object(*self._args, **self._kwargs) def __repr__(self) -> str: