Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions testtools/matchers/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import operator
import re
from collections.abc import Callable
from collections.abc import Callable, Sized
from pprint import pformat
from typing import Any, Generic, TypeVar

Expand Down Expand Up @@ -459,7 +459,7 @@ def match(self, value: str) -> Mismatch | None:
return None


def has_len(x: Any, y: int) -> bool:
def has_len(x: Sized, y: int) -> bool:
return len(x) == y


Expand Down
9 changes: 7 additions & 2 deletions testtools/monkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
"""Helpers for monkey-patching Python code."""

from collections.abc import Callable
from typing import Any
from typing import ParamSpec, TypeVar

_P = ParamSpec("_P")
_R = TypeVar("_R")

__all__ = [
"MonkeyPatcher",
Expand Down Expand Up @@ -71,7 +74,9 @@ def restore(self) -> None:
else:
setattr(obj, name, value)

def run_with_patches(self, f: Callable[..., Any], *args: Any, **kw: Any) -> Any:
def run_with_patches(
self, f: Callable[_P, _R], *args: _P.args, **kw: _P.kwargs
) -> _R:
"""Run 'f' with the given args and kwargs with all patches applied.

Restores all objects to their original state when finished.
Expand Down
4 changes: 2 additions & 2 deletions testtools/runtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import sys
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NoReturn

from testtools.testresult import (
ExcInfo,
Expand Down Expand Up @@ -252,7 +252,7 @@ def _got_user_exception(
return self.exception_caught


def _raise_force_fail_error() -> None:
def _raise_force_fail_error() -> NoReturn:
raise AssertionError("Forced Test Failure")


Expand Down
50 changes: 37 additions & 13 deletions testtools/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import types
import unittest
from collections.abc import Callable, Iterator
from typing import TYPE_CHECKING, NoReturn, ParamSpec, TypeVar, cast, overload
from typing import TYPE_CHECKING, Generic, NoReturn, ParamSpec, TypeVar, cast, overload
from unittest.case import SkipTest

T = TypeVar("T")
U = TypeVar("U")
_E = TypeVar("_E", bound=BaseException)
_E2 = TypeVar("_E2", bound=BaseException)
_E3 = TypeVar("_E3", bound=BaseException)

# ruff: noqa: E402 - TypeVars must be defined before importing testtools modules
from testtools import content
Expand Down Expand Up @@ -513,17 +516,38 @@ def assertRaises(
@overload # type: ignore[override]
def assertRaises(
self,
expected_exception: type[BaseException] | tuple[type[BaseException]],
expected_exception: type[_E],
callable: None = ...,
) -> "_AssertRaisesContext": ...
) -> "_AssertRaisesContext[_E]": ...

def assertRaises( # type: ignore[override]
@overload # type: ignore[override]
def assertRaises(
self,
expected_exception: type[BaseException] | tuple[type[BaseException]],
expected_exception: tuple[type[_E], type[_E2]],
callable: None = ...,
) -> "_AssertRaisesContext[_E | _E2]": ...

@overload # type: ignore[override]
def assertRaises(
self,
expected_exception: tuple[type[_E], type[_E2], type[_E3]],
callable: None = ...,
) -> "_AssertRaisesContext[_E | _E2 | _E3]": ...

@overload # type: ignore[override]
def assertRaises(
self,
expected_exception: tuple[type[BaseException], ...],
callable: None = ...,
) -> "_AssertRaisesContext[BaseException]": ...

def assertRaises( # type: ignore[override, misc]
self,
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
callable: Callable[_P, _R] | None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> "_AssertRaisesContext | BaseException":
) -> "_AssertRaisesContext[BaseException] | BaseException":
"""Fail unless an exception of class expected_exception is thrown
by callable when invoked with arguments args and keyword
arguments kwargs. If a different type of exception is
Expand Down Expand Up @@ -1185,7 +1209,7 @@ def decorator(test_item: _F) -> _F:
if not isinstance(test_item, class_types):

@functools.wraps(test_item)
def skip_wrapper(*args: object, **kwargs: object) -> None:
def skip_wrapper(*args: object, **kwargs: object) -> NoReturn:
raise TestCase.skipException(reason)

test_item = cast(_F, skip_wrapper)
Expand Down Expand Up @@ -1223,15 +1247,15 @@ def _id(obj: _F) -> _F:
return _id


class _AssertRaisesContext:
class _AssertRaisesContext(Generic[_E]):
"""A context manager to handle expected exceptions for assertRaises.

This provides compatibility with unittest's assertRaises context manager.
"""

def __init__(
self,
expected: type[BaseException] | tuple[type[BaseException]],
expected: type[_E] | tuple[type[BaseException], ...],
test_case: TestCase,
msg: str | None = None,
) -> None:
Expand All @@ -1244,7 +1268,7 @@ def __init__(
self.expected = expected
self.test_case = test_case
self.msg = msg
self.exception: BaseException | None = None
self.exception: _E | None = None

def __enter__(self) -> "Self":
return self
Expand Down Expand Up @@ -1274,7 +1298,7 @@ def __exit__(
# let unexpected exceptions pass through
return False
# store exception for later retrieval
self.exception = exc_value
self.exception = cast(_E, exc_value)
return True


Expand Down Expand Up @@ -1341,7 +1365,7 @@ def __exit__(
return True


class Nullary:
class Nullary(Generic[_R]):
"""Turn a callable into a nullary callable.

The advantage of this over ``lambda: f(*args, **kwargs)`` is that it
Expand All @@ -1358,7 +1382,7 @@ def __init__(
self._args = args
self._kwargs = kwargs

def __call__(self) -> object:
def __call__(self) -> _R:
return self._callable_object(*self._args, **self._kwargs)

def __repr__(self) -> str:
Expand Down