diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9990caaeb7a1..bef88b135429 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -6748,8 +6748,55 @@ def any_causes_overload_ambiguity( for formal in formals: matching_formals.append(matched_callable.arg_types[formal]) if not all_same_types(matching_formals) and not all_same_types(matching_returns): - # Any maps to multiple different types, and the return types of these items differ. + if _any_in_unrelated_position(get_proper_type(arg_type), matching_formals): + continue + return True + return False + + +def _any_in_unrelated_position(arg_type: ProperType, formals: list[Type]) -> bool: + """Whether Any in arg_type doesn't affect overload differentiation.""" + if isinstance(arg_type, AnyType): + return False + + proper_formals = [get_proper_type(f) for f in formals] + if len(proper_formals) < 2: + return False + + if isinstance(arg_type, TupleType): + if not all( + isinstance(f, TupleType) and len(f.items) == len(arg_type.items) + for f in proper_formals + ): + return False + tuple_formals = cast(list[TupleType], proper_formals) + for i, item in enumerate(arg_type.items): + item_proper = get_proper_type(item) + if has_any_type(item_proper): + formal_items = [tf.items[i] for tf in tuple_formals] + if all_same_types(formal_items): + continue + if not _any_in_unrelated_position(item_proper, formal_items): + return False + return True + + if isinstance(arg_type, CallableType): + if not all(isinstance(f, CallableType) for f in proper_formals): + return False + formal_callables = cast(list[CallableType], proper_formals) + first_args = formal_callables[0].arg_types + for fc in formal_callables[1:]: + if len(fc.arg_types) != len(first_args): + return False + for a1, a2 in zip(first_args, fc.arg_types): + if not is_same_type(a1, a2): + return False + return_types = [fc.ret_type for fc in formal_callables] + for candidate in return_types: + if all(is_subtype(candidate, other) for other in return_types): return True + return False + return False diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 68eccb46ab21..8d8017a37a3d 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -1906,6 +1906,73 @@ a: Any reveal_type(f(a)) # N: Revealed type is "def (*Any, **Any) -> Any" reveal_type(f(a)(a)) # N: Revealed type is "Any" +[case testOverloadWithNestedAnyInCallableDoesNotCauseAmbiguity] +# https://github.com/python/mypy/issues/20442 +from typing import overload, Any, Generic, Callable +from typing_extensions import TypeVar + +T = TypeVar("T") + +class MyClass(Generic[T]): + pass + +class Result1: + pass + +class Result2: + pass + +@overload +def test(x: Callable[[MyClass[Any]], int]) -> Result1: ... +@overload +def test(x: Callable[[MyClass[Any]], Any]) -> Result2: ... +def test(x: object) -> object: + return x + +def fn1(c: MyClass[Any]) -> int: + return 1 + +reveal_type(test(fn1)) # N: Revealed type is "__main__.Result1" +[builtins fixtures/tuple.pyi] + +[case testOverloadWithNestedAnyInTupleCallableDoesNotCauseAmbiguity] +# https://github.com/python/mypy/issues/20442 +from typing import overload, Any, Generic, Callable, Iterable +from collections.abc import Hashable +from typing_extensions import TypeVar, Self + +IndexT0 = TypeVar("IndexT0") +T = TypeVar("T") + +class Series: + pass + +class _LocIndexerFrame(Generic[T]): + @overload + def __getitem__(self, idx: tuple[Callable[[DataFrame], int], str]) -> int: ... + @overload + def __getitem__(self, idx: tuple[Callable[[DataFrame], list[Hashable]], str]) -> Series: ... + @overload + def __getitem__(self, idx: tuple[Callable[[DataFrame], Any], Iterable[Hashable]]) -> T: ... + def __getitem__(self, idx: object) -> object: + return idx + +class DataFrame(Generic[IndexT0]): + @property + def loc(self) -> _LocIndexerFrame[Self]: + return _LocIndexerFrame() + +def select2(df: DataFrame[Any]) -> list[Hashable]: + return [] + +def select3(_: DataFrame[Any]) -> int: + return 1 + +reveal_type(DataFrame().loc[select2, "x"]) # N: Revealed type is "__main__.Series" +reveal_type(DataFrame().loc[select3, "x"]) # N: Revealed type is "builtins.int" +[builtins fixtures/property.pyi] +[typing fixtures/typing-full.pyi] + [case testOverloadOnOverloadWithType] from typing import Any, Type, TypeVar, overload from mod import MyInt