diff --git a/mypy/join.py b/mypy/join.py index a074fa522588..a8c9910e60bb 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -377,24 +377,34 @@ def visit_instance(self, t: Instance) -> ProperType: return self.default(self.s) def visit_callable_type(self, t: CallableType) -> ProperType: - if isinstance(self.s, CallableType) and is_similar_callables(t, self.s): - if is_equivalent(t, self.s): - return combine_similar_callables(t, self.s) - result = join_similar_callables(t, self.s) - # We set the from_type_type flag to suppress error when a collection of - # concrete class objects gets inferred as their common abstract superclass. - if not ( - (t.is_type_obj() and t.type_object().is_abstract) - or (self.s.is_type_obj() and self.s.type_object().is_abstract) - ): - result.from_type_type = True - if any( - isinstance(tp, (NoneType, UninhabitedType)) - for tp in get_proper_types(result.arg_types) - ): - # We don't want to return unusable Callable, attempt fallback instead. + if isinstance(self.s, CallableType): + if is_similar_callables(t, self.s): + if is_equivalent(t, self.s): + return combine_similar_callables(t, self.s) + result = join_similar_callables(t, self.s) + if any( + isinstance(tp, (NoneType, UninhabitedType)) + for tp in get_proper_types(result.arg_types) + ): + # We don't want to return unusable Callable, attempt fallback instead. + return join_types(t.fallback, self.s) + # We set the from_type_type flag to suppress error when a collection of + # concrete class objects gets inferred as their common abstract superclass. + if not ( + (t.is_type_obj() and t.type_object().is_abstract) + or (self.s.is_type_obj() and self.s.type_object().is_abstract) + ): + result.from_type_type = True + return result + else: + s2, t2 = self.s, t + if t2.is_var_arg: + s2, t2 = t2, s2 + if is_subtype(s2, t2): + return t2.copy_modified() + elif is_subtype(t2, s2): + return s2.copy_modified() return join_types(t.fallback, self.s) - return result elif isinstance(self.s, Overloaded): # Switch the order of arguments to that we'll get to visit_overloaded. return join_types(t, self.s) diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index b54dffe836b8..a0762ab78f48 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3512,6 +3512,56 @@ class Qux(Bar): pass [builtins fixtures/tuple.pyi] +[case testCallableJoinWithDefaults] +from typing import Callable, TypeVar + +T = TypeVar("T") + +def join(t1: T, t2: T) -> T: ... + +def f1() -> None: ... +def f2(i: int = 0) -> None: ... +def f3(i: str = "") -> None: ... + +reveal_type(join(f1, f2)) # N: Revealed type is "def ()" +reveal_type(join(f1, f3)) # N: Revealed type is "def ()" +reveal_type(join(f2, f3)) # N: Revealed type is "builtins.function" # TODO: this could be better +[builtins fixtures/tuple.pyi] + +[case testCallableJoinWithDefaultsMultiple] +from typing import TypeVar +T = TypeVar("T") +def join(t1: T, t2: T, t3: T) -> T: ... + +def f_1(common, a=None): ... +def f_any(*_, **__): ... +def f_3(common, b=None, x=None): ... + +fdict = { + "f_1": f_1, + "f_any": f_any, + "f_3": f_3, +} +reveal_type(fdict) # N: Revealed type is "builtins.dict[builtins.str, def (common: Any, a: Any =) -> Any]" + +reveal_type(join(f_1, f_any, f_3)) # N: Revealed type is "def (common: Any, a: Any =) -> Any" + +[builtins fixtures/tuple.pyi] + +[case testCallableJoinWithType] +from __future__ import annotations +class Exc: ... +class AttributeErr(Exc): + def __init__(self, *args: object) -> None: ... +class FnfErr(Exc): ... + +x = [ + FnfErr, + AttributeErr, +] +reveal_type(x) # N: Revealed type is "builtins.list[builtins.type]" +[builtins fixtures/type.pyi] + [case testDistinctFormatting] from typing import Awaitable, Callable, ParamSpec