Skip to content

Commit f291d54

Browse files
committed
Honor return type of __new__
This basically follows the approach Jukka laid out in #1020 four years ago: * If the return type is Any, ignore that and keep the class type as the return type * Otherwise respect `__new__`'s return type * Produce an error if the return type is not a subtype of the class. The main motivation for me in implementing this is to support overloading `__new__` in order to select type variable arguments, which will be useful for subprocess.Popen. Fixes #1020.
1 parent e479b6d commit f291d54

11 files changed

+190
-46
lines changed

mypy/checker.py

+25
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,10 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
803803
self.fail(message_registry.MUST_HAVE_NONE_RETURN_TYPE.format(fdef.name()),
804804
item)
805805

806+
# Check validity of __new__ signature
807+
if fdef.info and fdef.name() == '__new__':
808+
self.check___new___signature(fdef, typ)
809+
806810
self.check_for_missing_annotations(fdef)
807811
if self.options.disallow_any_unimported:
808812
if fdef.type and isinstance(fdef.type, CallableType):
@@ -1015,6 +1019,27 @@ def is_unannotated_any(t: Type) -> bool:
10151019
if any(is_unannotated_any(t) for t in fdef.type.arg_types):
10161020
self.fail(message_registry.ARGUMENT_TYPE_EXPECTED, fdef)
10171021

1022+
def check___new___signature(self, fdef: FuncDef, typ: CallableType) -> None:
1023+
self_type = fill_typevars_with_any(fdef.info)
1024+
bound_type = bind_self(typ, self_type, is_classmethod=True)
1025+
# Check that __new__ (after binding cls) returns an instance
1026+
# type (or any)
1027+
if not isinstance(bound_type.ret_type, (AnyType, Instance, TupleType)):
1028+
self.fail(
1029+
message_registry.NON_INSTANCE_NEW_TYPE.format(
1030+
self.msg.format(bound_type.ret_type)),
1031+
fdef)
1032+
else:
1033+
# And that it returns a subtype of the class
1034+
self.check_subtype(
1035+
bound_type.ret_type,
1036+
self_type,
1037+
fdef,
1038+
message_registry.INVALID_NEW_TYPE,
1039+
'returns',
1040+
'but must return a subtype of'
1041+
)
1042+
10181043
def is_trivial_body(self, block: Block) -> bool:
10191044
"""Returns 'true' if the given body is "trivial" -- if it contains just a "pass",
10201045
"..." (ellipsis), or "raise NotImplementedError()". A trivial body may also

mypy/checkmember.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -803,8 +803,10 @@ def type_object_type(info: TypeInfo, builtin_type: Callable[[str], Instance]) ->
803803
fallback = info.metaclass_type or builtin_type('builtins.type')
804804
if init_index < new_index:
805805
method = init_method.node # type: Union[FuncBase, Decorator]
806+
is_new = False
806807
elif init_index > new_index:
807808
method = new_method.node
809+
is_new = True
808810
else:
809811
if init_method.node.info.fullname() == 'builtins.object':
810812
# Both are defined by object. But if we've got a bogus
@@ -823,14 +825,15 @@ def type_object_type(info: TypeInfo, builtin_type: Callable[[str], Instance]) ->
823825
# is the right thing, but __new__ caused problems with
824826
# typeshed (#5647).
825827
method = init_method.node
828+
is_new = False
826829
# Construct callable type based on signature of __init__. Adjust
827830
# return type and insert type arguments.
828831
if isinstance(method, FuncBase):
829832
t = function_type(method, fallback)
830833
else:
831834
assert isinstance(method.type, FunctionLike) # is_valid_constructor() ensures this
832835
t = method.type
833-
return type_object_type_from_function(t, info, method.info, fallback)
836+
return type_object_type_from_function(t, info, method.info, fallback, is_new)
834837

835838

836839
def is_valid_constructor(n: Optional[SymbolNode]) -> bool:
@@ -849,7 +852,8 @@ def is_valid_constructor(n: Optional[SymbolNode]) -> bool:
849852
def type_object_type_from_function(signature: FunctionLike,
850853
info: TypeInfo,
851854
def_info: TypeInfo,
852-
fallback: Instance) -> FunctionLike:
855+
fallback: Instance,
856+
is_new: bool) -> FunctionLike:
853857
# The __init__ method might come from a generic superclass
854858
# (init_or_new.info) with type variables that do not map
855859
# identically to the type variables of the class being constructed
@@ -859,7 +863,7 @@ def type_object_type_from_function(signature: FunctionLike,
859863
# class B(A[List[T]], Generic[T]): pass
860864
#
861865
# We need to first map B's __init__ to the type (List[T]) -> None.
862-
signature = bind_self(signature)
866+
signature = bind_self(signature, original_type=fill_typevars(info), is_classmethod=is_new)
863867
signature = cast(FunctionLike,
864868
map_type_from_supertype(signature, info, def_info))
865869
special_sig = None # type: Optional[str]
@@ -868,25 +872,32 @@ def type_object_type_from_function(signature: FunctionLike,
868872
special_sig = 'dict'
869873

870874
if isinstance(signature, CallableType):
871-
return class_callable(signature, info, fallback, special_sig)
875+
return class_callable(signature, info, fallback, special_sig, is_new)
872876
else:
873877
# Overloaded __init__/__new__.
874878
assert isinstance(signature, Overloaded)
875879
items = [] # type: List[CallableType]
876880
for item in signature.items():
877-
items.append(class_callable(item, info, fallback, special_sig))
881+
items.append(class_callable(item, info, fallback, special_sig, is_new))
878882
return Overloaded(items)
879883

880884

881885
def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance,
882-
special_sig: Optional[str]) -> CallableType:
886+
special_sig: Optional[str],
887+
is_new: bool = False) -> CallableType:
883888
"""Create a type object type based on the signature of __init__."""
884889
variables = [] # type: List[TypeVarDef]
885890
variables.extend(info.defn.type_vars)
886891
variables.extend(init_type.variables)
887892

893+
is_new = True
894+
if is_new and isinstance(init_type.ret_type, (Instance, TupleType)):
895+
ret_type = init_type.ret_type # type: Type
896+
else:
897+
ret_type = fill_typevars(info)
898+
888899
callable_type = init_type.copy_modified(
889-
ret_type=fill_typevars(info), fallback=type_type, name=None, variables=variables,
900+
ret_type=ret_type, fallback=type_type, name=None, variables=variables,
890901
special_sig=special_sig)
891902
c = callable_type.with_name(info.name())
892903
return c

mypy/interpreted_plugin.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class InterpretedPlugin:
2222
that proxies to this interpreted version.
2323
"""
2424

25-
def __new__(cls, *args: Any, **kwargs: Any) -> 'mypy.plugin.Plugin':
25+
# ... mypy doesn't like these shenanigans so we have to type ignore it!
26+
def __new__(cls, *args: Any, **kwargs: Any) -> 'mypy.plugin.Plugin': # type: ignore
2627
from mypy.plugin import WrapperPlugin
2728
plugin = object.__new__(cls)
2829
plugin.__init__(*args, **kwargs)

mypy/message_registry.py

+2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
INVALID_SLICE_INDEX = 'Slice index must be an integer or None' # type: Final
5252
CANNOT_INFER_LAMBDA_TYPE = 'Cannot infer type of lambda' # type: Final
5353
CANNOT_ACCESS_INIT = 'Cannot access "__init__" directly' # type: Final
54+
NON_INSTANCE_NEW_TYPE = '"__new__" must return a class instance (got {})' # type: Final
55+
INVALID_NEW_TYPE = 'Incompatible return type for "__new__"' # type: Final
5456
BAD_CONSTRUCTOR_TYPE = 'Unsupported decorated constructor type' # type: Final
5557
CANNOT_ASSIGN_TO_METHOD = 'Cannot assign to a method' # type: Final
5658
CANNOT_ASSIGN_TO_TYPE = 'Cannot assign to a type' # type: Final

test-data/unit/check-class-namedtuple.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ class XMethBad(NamedTuple):
598598
class MagicalFields(NamedTuple):
599599
x: int
600600
def __slots__(self) -> None: pass # E: Cannot overwrite NamedTuple attribute "__slots__"
601-
def __new__(cls) -> None: pass # E: Cannot overwrite NamedTuple attribute "__new__"
601+
def __new__(cls) -> MagicalFields: pass # E: Cannot overwrite NamedTuple attribute "__new__"
602602
def _source(self) -> int: pass # E: Cannot overwrite NamedTuple attribute "_source"
603603
__annotations__ = {'x': float} # E: NamedTuple field name cannot start with an underscore: __annotations__ \
604604
# E: Invalid statement in NamedTuple definition; expected "field_name: field_type [= default]" \

test-data/unit/check-classes.test

+108-6
Original file line numberDiff line numberDiff line change
@@ -344,12 +344,12 @@ main:6: error: Return type "A" of "f" incompatible with return type "None" in su
344344

345345
[case testOverride__new__WithDifferentSignature]
346346
class A:
347-
def __new__(cls, x: int) -> str:
348-
return ''
347+
def __new__(cls, x: int) -> A:
348+
pass
349349

350350
class B(A):
351-
def __new__(cls) -> int:
352-
return 1
351+
def __new__(cls) -> B:
352+
pass
353353

354354
[case testOverride__new__AndCallObject]
355355
from typing import TypeVar, Generic
@@ -5363,8 +5363,8 @@ class A:
53635363
pass
53645364

53655365
class B(A):
5366-
def __new__(cls) -> int:
5367-
return 10
5366+
def __new__(cls) -> B:
5367+
pass
53685368

53695369
B()
53705370

@@ -5975,3 +5975,105 @@ class E(C):
59755975
reveal_type(self.x) # N: Revealed type is 'builtins.int'
59765976

59775977
[targets __main__, __main__, __main__.D.g, __main__.D.f, __main__.C.__init__, __main__.E.g, __main__.E.f]
5978+
5979+
[case testNewReturnType1]
5980+
class A:
5981+
def __new__(cls) -> B:
5982+
pass
5983+
5984+
class B(A): pass
5985+
5986+
reveal_type(A()) # N: Revealed type is '__main__.B'
5987+
reveal_type(B()) # N: Revealed type is '__main__.B'
5988+
5989+
[case testNewReturnType2]
5990+
from typing import Any
5991+
5992+
# make sure that __new__ method that return Any are ignored when
5993+
# determining the return type
5994+
class A:
5995+
def __new__(cls):
5996+
pass
5997+
5998+
class B:
5999+
def __new__(cls) -> Any:
6000+
pass
6001+
6002+
reveal_type(A()) # N: Revealed type is '__main__.A'
6003+
reveal_type(B()) # N: Revealed type is '__main__.B'
6004+
6005+
[case testNewReturnType3]
6006+
6007+
# Check for invalid __new__ typing
6008+
6009+
class A:
6010+
def __new__(cls) -> int: # E: Incompatible return type for "__new__" (returns "int", but must return a subtype of "A")
6011+
pass
6012+
6013+
reveal_type(A()) # N: Revealed type is 'builtins.int'
6014+
6015+
[case testNewReturnType4]
6016+
from typing import TypeVar, Type
6017+
6018+
# Check for __new__ using type vars
6019+
6020+
TX = TypeVar('TX', bound='X')
6021+
class X:
6022+
def __new__(lol: Type[TX], x: int) -> TX:
6023+
pass
6024+
class Y(X): pass
6025+
6026+
reveal_type(X(20)) # N: Revealed type is '__main__.X*'
6027+
reveal_type(Y(20)) # N: Revealed type is '__main__.Y*'
6028+
6029+
[case testNewReturnType5]
6030+
from typing import Any, TypeVar, Generic, overload
6031+
6032+
T = TypeVar('T')
6033+
class O(Generic[T]):
6034+
@overload
6035+
def __new__(cls) -> O[int]:
6036+
pass
6037+
@overload
6038+
def __new__(cls, x: int) -> O[str]:
6039+
pass
6040+
def __new__(cls, x: int = 0) -> O[Any]:
6041+
pass
6042+
6043+
reveal_type(O()) # N: Revealed type is '__main__.O[builtins.int]'
6044+
reveal_type(O(10)) # N: Revealed type is '__main__.O[builtins.str]'
6045+
6046+
[case testNewReturnType6]
6047+
from typing import Tuple, Optional
6048+
6049+
# Check for some cases that aren't allowed
6050+
6051+
class X:
6052+
def __new__(cls) -> Optional[Y]: # E: "__new__" must return a class instance (got "Optional[Y]")
6053+
pass
6054+
class Y:
6055+
def __new__(cls) -> Optional[int]: # E: "__new__" must return a class instance (got "Optional[int]")
6056+
pass
6057+
6058+
6059+
[case testNewReturnType7]
6060+
from typing import NamedTuple
6061+
6062+
# ... test __new__ returning tuple type
6063+
class A:
6064+
def __new__(cls) -> 'B':
6065+
pass
6066+
6067+
N = NamedTuple('N', [('x', int)])
6068+
class B(A, N): pass
6069+
6070+
reveal_type(A()) # N: Revealed type is 'Tuple[builtins.int, fallback=__main__.B]'
6071+
6072+
[case testNewReturnType8]
6073+
from typing import TypeVar, Any
6074+
6075+
# test type var from a different argument
6076+
TX = TypeVar('TX', bound='X')
6077+
class X:
6078+
def __new__(cls, x: TX) -> TX: # E: "__new__" must return a class instance (got "TX")
6079+
pass

0 commit comments

Comments
 (0)