-
-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for functools.partial #16939
Changes from 11 commits
5b56460
ff4914f
c0084eb
a847234
2a5f3f1
325cae4
1ee376d
b7ca434
03b397e
d2886ac
65e356a
c7f6d78
0b46081
1533351
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,22 @@ | |
|
||
from typing import Final, NamedTuple | ||
|
||
import mypy.checker | ||
import mypy.plugin | ||
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var | ||
from mypy.argmap import map_actuals_to_formals | ||
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var | ||
from mypy.plugins.common import add_method_to_class | ||
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type | ||
from mypy.types import ( | ||
AnyType, | ||
CallableType, | ||
Instance, | ||
Overloaded, | ||
Type, | ||
TypeOfAny, | ||
UnboundType, | ||
UninhabitedType, | ||
get_proper_type, | ||
) | ||
|
||
functools_total_ordering_makers: Final = {"functools.total_ordering"} | ||
|
||
|
@@ -102,3 +114,131 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | | |
comparison_methods[name] = None | ||
|
||
return comparison_methods | ||
|
||
|
||
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: | ||
"""Infer a more precise return type for functools.partial""" | ||
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals | ||
return ctx.default_return_type | ||
if len(ctx.arg_types) != 3: # fn, *args, **kwargs | ||
return ctx.default_return_type | ||
if len(ctx.arg_types[0]) != 1: | ||
return ctx.default_return_type | ||
|
||
if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded): | ||
# TODO: handle overloads, just fall back to whatever the non-plugin code does | ||
return ctx.default_return_type | ||
fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type) | ||
if fn_type is None: | ||
return ctx.default_return_type | ||
|
||
defaulted = fn_type.copy_modified( | ||
arg_kinds=[ | ||
( | ||
ArgKind.ARG_OPT | ||
if k == ArgKind.ARG_POS | ||
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k) | ||
) | ||
for k in fn_type.arg_kinds | ||
] | ||
) | ||
if defaulted.line < 0: | ||
# Make up a line number if we don't have one | ||
defaulted.set_line(ctx.default_return_type) | ||
|
||
actual_args = [a for param in ctx.args[1:] for a in param] | ||
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param] | ||
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param] | ||
actual_types = [a for param in ctx.arg_types[1:] for a in param] | ||
|
||
_, bound = ctx.api.expr_checker.check_call( | ||
callee=defaulted, | ||
args=actual_args, | ||
arg_kinds=actual_arg_kinds, | ||
arg_names=actual_arg_names, | ||
context=defaulted, | ||
) | ||
bound = get_proper_type(bound) | ||
if not isinstance(bound, CallableType): | ||
return ctx.default_return_type | ||
|
||
formal_to_actual = map_actuals_to_formals( | ||
actual_kinds=actual_arg_kinds, | ||
actual_names=actual_arg_names, | ||
formal_kinds=fn_type.arg_kinds, | ||
formal_names=fn_type.arg_names, | ||
actual_arg_type=lambda i: actual_types[i], | ||
) | ||
|
||
partial_kinds = [] | ||
partial_types = [] | ||
partial_names = [] | ||
# We need to fully apply any positional arguments (they cannot be respecified) | ||
# However, keyword arguments can be respecified, so just give them a default | ||
for i, actuals in enumerate(formal_to_actual): | ||
if len(bound.arg_types) == len(fn_type.arg_types): | ||
arg_type = bound.arg_types[i] | ||
if isinstance(get_proper_type(arg_type), UninhabitedType): | ||
arg_type = fn_type.arg_types[i] # bit of a hack | ||
else: | ||
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't | ||
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple | ||
arg_type = fn_type.arg_types[i] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to above. |
||
|
||
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2): | ||
partial_kinds.append(fn_type.arg_kinds[i]) | ||
partial_types.append(arg_type) | ||
partial_names.append(fn_type.arg_names[i]) | ||
elif actuals: | ||
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals): | ||
continue | ||
kind = actual_arg_kinds[actuals[0]] | ||
if kind == ArgKind.ARG_NAMED: | ||
kind = ArgKind.ARG_NAMED_OPT | ||
partial_kinds.append(kind) | ||
partial_types.append(arg_type) | ||
partial_names.append(fn_type.arg_names[i]) | ||
|
||
ret_type = bound.ret_type | ||
if isinstance(get_proper_type(ret_type), UninhabitedType): | ||
ret_type = fn_type.ret_type # same kind of hack as above | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to above -- it seems that this might leak type variables. If that is the case, it would probably be better to fall back to the default return type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this doesn't leak type variables, because But let me know if I'm off the mark! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you are right, and this is good already. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this still can leak type variables in more tricky cases. For example:
Btw when looking at whether this is correctly supported by current polymorphic inference I found that it is supported, but there is a spurious error at the first call site:
This is caused by the first bullet point in #15907. I guess I will need to finally fix this, I was procrastinating for too long (but to be fair fixing this would touch some really old and hacky code). |
||
|
||
partially_applied = fn_type.copy_modified( | ||
arg_types=partial_types, | ||
arg_kinds=partial_kinds, | ||
arg_names=partial_names, | ||
ret_type=ret_type, | ||
) | ||
|
||
ret = ctx.api.named_generic_type("functools.partial", [ret_type]) | ||
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Now if we import I'm not sure what is the best way to fix this. Probably the simplest option would be to serialize Also it would be good to have an incremental mode test case. It looks like mypy daemon doesn't keep track of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh well. IIRC I added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also now that you added some persistence for |
||
return ret | ||
|
||
|
||
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: | ||
"""Infer a more precise return type for functools.partial.__call__.""" | ||
if ( | ||
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals | ||
or not isinstance(ctx.type, Instance) | ||
or ctx.type.type.fullname != "functools.partial" | ||
or not ctx.type.extra_attrs | ||
or "__mypy_partial" not in ctx.type.extra_attrs.attrs | ||
): | ||
return ctx.default_return_type | ||
|
||
partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"] | ||
if len(ctx.arg_types) != 2: # *args, **kwargs | ||
return ctx.default_return_type | ||
|
||
args = [a for param in ctx.args for a in param] | ||
arg_kinds = [a for param in ctx.arg_kinds for a in param] | ||
arg_names = [a for param in ctx.arg_names for a in param] | ||
|
||
result = ctx.api.expr_checker.check_call( | ||
callee=partial_type, | ||
args=args, | ||
arg_kinds=arg_kinds, | ||
arg_names=arg_names, | ||
context=ctx.context, | ||
) | ||
return result[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm this could refer to a type variable type, so I wonder if this can leak type variables, in the target function is a generic one, and one of the provided arguments can be used to bind the type variable.