Skip to content

Commit 6c1d867

Browse files
authored
Fix ParamSpec inference against TypeVarTuple (#17431)
Fixes #17278 Fixes #17127
1 parent 620e281 commit 6c1d867

File tree

5 files changed

+85
-13
lines changed

5 files changed

+85
-13
lines changed

mypy/constraints.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,11 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10711071
# (with literal '...').
10721072
if not template.is_ellipsis_args:
10731073
unpack_present = find_unpack_in_list(template.arg_types)
1074-
if unpack_present is not None:
1074+
# When both ParamSpec and TypeVarTuple are present, things become messy
1075+
# quickly. For now, we only allow ParamSpec to "capture" TypeVarTuple,
1076+
# but not vice versa.
1077+
# TODO: infer more from prefixes when possible.
1078+
if unpack_present is not None and not cactual.param_spec():
10751079
# We need to re-normalize args to the form they appear in tuples,
10761080
# for callables we always pack the suffix inside another tuple.
10771081
unpack = template.arg_types[unpack_present]

mypy/expandtype.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,13 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
270270
repl = self.variables.get(t.id, t)
271271
if isinstance(repl, TypeVarTupleType):
272272
return repl
273+
elif isinstance(repl, ProperType) and isinstance(repl, (AnyType, UninhabitedType)):
274+
# Some failed inference scenarios will try to set all type variables to Never.
275+
# Instead of being picky and require all the callers to wrap them,
276+
# do this here instead.
277+
# Note: most cases when this happens are handled in expand unpack below, but
278+
# in rare cases (e.g. ParamSpec containing Unpack star args) it may be skipped.
279+
return t.tuple_fallback.copy_modified(args=[repl])
273280
raise NotImplementedError
274281

275282
def visit_unpack_type(self, t: UnpackType) -> Type:
@@ -348,7 +355,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
348355
# the replacement is ignored.
349356
if isinstance(repl, Parameters):
350357
# We need to expand both the types in the prefix and the ParamSpec itself
351-
return t.copy_modified(
358+
expanded = t.copy_modified(
352359
arg_types=self.expand_types(t.arg_types[:-2]) + repl.arg_types,
353360
arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds,
354361
arg_names=t.arg_names[:-2] + repl.arg_names,
@@ -358,6 +365,11 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
358365
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
359366
variables=[*repl.variables, *t.variables],
360367
)
368+
var_arg = expanded.var_arg()
369+
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
370+
# Sometimes we get new unpacks after expanding ParamSpec.
371+
expanded.normalize_trivial_unpack()
372+
return expanded
361373
elif isinstance(repl, ParamSpecType):
362374
# We're substituting one ParamSpec for another; this can mean that the prefix
363375
# changes, e.g. substitute Concatenate[int, P] in place of Q.

mypy/semanal_typeargs.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from mypy.message_registry import INVALID_PARAM_SPEC_LOCATION, INVALID_PARAM_SPEC_LOCATION_NOTE
1616
from mypy.messages import format_type
1717
from mypy.mixedtraverser import MixedTraverserVisitor
18-
from mypy.nodes import ARG_STAR, Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile
18+
from mypy.nodes import Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile
1919
from mypy.options import Options
2020
from mypy.scope import Scope
2121
from mypy.subtypes import is_same_type, is_subtype
@@ -104,15 +104,7 @@ def visit_tuple_type(self, t: TupleType) -> None:
104104

105105
def visit_callable_type(self, t: CallableType) -> None:
106106
super().visit_callable_type(t)
107-
# Normalize trivial unpack in var args as *args: *tuple[X, ...] -> *args: X
108-
if t.is_var_arg:
109-
star_index = t.arg_kinds.index(ARG_STAR)
110-
star_type = t.arg_types[star_index]
111-
if isinstance(star_type, UnpackType):
112-
p_type = get_proper_type(star_type.type)
113-
if isinstance(p_type, Instance):
114-
assert p_type.type.fullname == "builtins.tuple"
115-
t.arg_types[star_index] = p_type.args[0]
107+
t.normalize_trivial_unpack()
116108

117109
def visit_instance(self, t: Instance) -> None:
118110
super().visit_instance(t)

mypy/types.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -2084,6 +2084,17 @@ def param_spec(self) -> ParamSpecType | None:
20842084
prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2])
20852085
return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix)
20862086

2087+
def normalize_trivial_unpack(self) -> None:
2088+
# Normalize trivial unpack in var args as *args: *tuple[X, ...] -> *args: X in place.
2089+
if self.is_var_arg:
2090+
star_index = self.arg_kinds.index(ARG_STAR)
2091+
star_type = self.arg_types[star_index]
2092+
if isinstance(star_type, UnpackType):
2093+
p_type = get_proper_type(star_type.type)
2094+
if isinstance(p_type, Instance):
2095+
assert p_type.type.fullname == "builtins.tuple"
2096+
self.arg_types[star_index] = p_type.args[0]
2097+
20872098
def with_unpacked_kwargs(self) -> NormalizedCallableType:
20882099
if not self.unpack_kwargs:
20892100
return cast(NormalizedCallableType, self)
@@ -2113,7 +2124,7 @@ def with_normalized_var_args(self) -> Self:
21132124
if not isinstance(unpacked, TupleType):
21142125
# Note that we don't normalize *args: *tuple[X, ...] -> *args: X,
21152126
# this should be done once in semanal_typeargs.py for user-defined types,
2116-
# and we ourselves should never construct such type.
2127+
# and we ourselves rarely construct such type.
21172128
return self
21182129
unpack_index = find_unpack_in_list(unpacked.items)
21192130
if unpack_index == 0 and len(unpacked.items) > 1:

test-data/unit/check-typevar-tuple.test

+53
Original file line numberDiff line numberDiff line change
@@ -2407,3 +2407,56 @@ reveal_type(x) # N: Revealed type is "__main__.C[builtins.str, builtins.int]"
24072407
reveal_type(C(f)) # N: Revealed type is "__main__.C[builtins.str, builtins.int, builtins.int, builtins.int, builtins.int]"
24082408
C[()] # E: At least 1 type argument(s) expected, none given
24092409
[builtins fixtures/tuple.pyi]
2410+
2411+
[case testTypeVarTupleAgainstParamSpecActualSuccess]
2412+
from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List
2413+
from typing_extensions import ParamSpec
2414+
2415+
R = TypeVar("R")
2416+
P = ParamSpec("P")
2417+
2418+
class CM(Generic[R]): ...
2419+
def cm(fn: Callable[P, R]) -> Callable[P, CM[R]]: ...
2420+
2421+
Ts = TypeVarTuple("Ts")
2422+
@cm
2423+
def test(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: ...
2424+
2425+
reveal_type(test) # N: Revealed type is "def [Ts] (*args: Unpack[Ts`-1]) -> __main__.CM[Tuple[Unpack[Ts`-1]]]"
2426+
reveal_type(test(1, 2, 3)) # N: Revealed type is "__main__.CM[Tuple[Literal[1]?, Literal[2]?, Literal[3]?]]"
2427+
[builtins fixtures/tuple.pyi]
2428+
2429+
[case testTypeVarTupleAgainstParamSpecActualFailedNoCrash]
2430+
from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List
2431+
from typing_extensions import ParamSpec
2432+
2433+
R = TypeVar("R")
2434+
P = ParamSpec("P")
2435+
2436+
class CM(Generic[R]): ...
2437+
def cm(fn: Callable[P, List[R]]) -> Callable[P, CM[R]]: ...
2438+
2439+
Ts = TypeVarTuple("Ts")
2440+
@cm # E: Argument 1 to "cm" has incompatible type "Callable[[VarArg(Unpack[Ts])], Tuple[Unpack[Ts]]]"; expected "Callable[[VarArg(Never)], List[Never]]"
2441+
def test(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: ...
2442+
2443+
reveal_type(test) # N: Revealed type is "def (*args: Never) -> __main__.CM[Never]"
2444+
[builtins fixtures/tuple.pyi]
2445+
2446+
[case testTypeVarTupleAgainstParamSpecActualPrefix]
2447+
from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List
2448+
from typing_extensions import ParamSpec, Concatenate
2449+
2450+
R = TypeVar("R")
2451+
P = ParamSpec("P")
2452+
T = TypeVar("T")
2453+
2454+
class CM(Generic[R]): ...
2455+
def cm(fn: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[List[T], P], CM[R]]: ...
2456+
2457+
Ts = TypeVarTuple("Ts")
2458+
@cm
2459+
def test(x: T, *args: Unpack[Ts]) -> Tuple[T, Unpack[Ts]]: ...
2460+
2461+
reveal_type(test) # N: Revealed type is "def [T, Ts] (builtins.list[T`2], *args: Unpack[Ts`-2]) -> __main__.CM[Tuple[T`2, Unpack[Ts`-2]]]"
2462+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)