Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - Allow infer types in expression type argument positions #22368

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/compiler/binder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,8 @@ namespace ts {
return ContainerFlags.IsContainer | ContainerFlags.HasLocals;

case SyntaxKind.ConditionalType:
case SyntaxKind.CallExpression:
case SyntaxKind.NewExpression:
return ContainerFlags.IsInferenceContainer;

case SyntaxKind.SourceFile:
Expand Down
89 changes: 78 additions & 11 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3021,6 +3021,11 @@ namespace ts {
if (type.flags & TypeFlags.Substitution) {
return typeToTypeNodeHelper((<SubstitutionType>type).typeParameter, context);
}
if (type.flags & TypeFlags.InferType) {
// Infer types only parse as identifiers, so the target should always be a TypeParameter that becomes a TypeReferenceNode
const ref = typeToTypeNodeHelper((<InferType>type).target, context) as TypeReferenceNode;
return createInferTypeNode(createTypeParameterDeclaration(ref.typeName as Identifier));
}

Debug.fail("Should be unreachable.");

Expand Down Expand Up @@ -3517,7 +3522,7 @@ namespace ts {
const params = getTypeParametersOfClassOrInterface(
parentSymbol.flags & SymbolFlags.Alias ? resolveAlias(parentSymbol) : parentSymbol
);
typeParameterNodes = mapToTypeNodes(map(params, (nextSymbol as TransientSymbol).mapper), context);
typeParameterNodes = mapToTypeNodes(mapIndexless(params, (nextSymbol as TransientSymbol).mapper), context);
}
else {
typeParameterNodes = typeParametersToTypeParameterDeclarations(symbol, context);
Expand Down Expand Up @@ -4736,12 +4741,14 @@ namespace ts {
case SyntaxKind.JSDocTemplateTag:
case SyntaxKind.MappedType:
case SyntaxKind.ConditionalType:
case SyntaxKind.CallExpression:
case SyntaxKind.NewExpression:
const outerTypeParameters = getOuterTypeParameters(node, includeThisTypes);
if (node.kind === SyntaxKind.MappedType) {
return append(outerTypeParameters, getDeclaredTypeOfTypeParameter(getSymbolOfNode((<MappedTypeNode>node).typeParameter)));
}
else if (node.kind === SyntaxKind.ConditionalType) {
return concatenate(outerTypeParameters, getInferTypeParameters(<ConditionalTypeNode>node));
else if (node.kind === SyntaxKind.ConditionalType || node.kind === SyntaxKind.NewExpression || node.kind === SyntaxKind.CallExpression) {
return concatenate(outerTypeParameters, getInferTypeParameters(<ConditionalTypeNode | CallLikeExpression>node));
}
const outerAndOwnTypeParameters = appendTypeParameters(outerTypeParameters, getEffectiveTypeParameterDeclarations(<DeclarationWithTypeParameters>node) || emptyArray);
const thisType = includeThisTypes &&
Expand Down Expand Up @@ -8334,7 +8341,7 @@ namespace ts {
return type.resolvedFalseType || (type.resolvedFalseType = instantiateType(type.root.falseType, type.mapper));
}

function getInferTypeParameters(node: ConditionalTypeNode): TypeParameter[] {
function getInferTypeParameters(node: ConditionalTypeNode | CallLikeExpression): TypeParameter[] {
let result: TypeParameter[];
if (node.locals) {
node.locals.forEach(symbol => {
Expand Down Expand Up @@ -8375,10 +8382,16 @@ namespace ts {
return links.resolvedType;
}

function createInferType(target: TypeParameter): InferType {
const type = createType(TypeFlags.InferType) as InferType;
type.target = target;
return type;
}

function getTypeFromInferTypeNode(node: InferTypeNode): Type {
const links = getNodeLinks(node);
if (!links.resolvedType) {
links.resolvedType = getDeclaredTypeOfTypeParameter(getSymbolOfNode(node.typeParameter));
links.resolvedType = createInferType(getDeclaredTypeOfTypeParameter(getSymbolOfNode(node.typeParameter)));
}
return links.resolvedType;
}
Expand Down Expand Up @@ -8882,7 +8895,7 @@ namespace ts {
// mapper to the type parameters to produce the effective list of type arguments, and compute the
// instantiation cache key from the type IDs of the type arguments.
const combinedMapper = type.objectFlags & ObjectFlags.Instantiated ? combineTypeMappers(type.mapper, mapper) : mapper;
const typeArguments = map(typeParameters, combinedMapper);
const typeArguments = mapIndexless(typeParameters, combinedMapper);
const id = getTypeListId(typeArguments);
let result = links.instantiations.get(id);
if (!result) {
Expand Down Expand Up @@ -8965,7 +8978,7 @@ namespace ts {
// We are instantiating a conditional type that has one or more type parameters in scope. Apply the
// mapper to the type parameters to produce the effective list of type arguments, and compute the
// instantiation cache key from the type IDs of the type arguments.
const typeArguments = map(root.outerTypeParameters, mapper);
const typeArguments = mapIndexless(root.outerTypeParameters, mapper);
const id = getTypeListId(typeArguments);
let result = root.instantiations.get(id);
if (!result) {
Expand Down Expand Up @@ -9036,6 +9049,15 @@ namespace ts {
if (type.flags & TypeFlags.Substitution) {
return mapper((<SubstitutionType>type).typeParameter);
}
if (type.flags & TypeFlags.InferType) {
// Fresh infer types are not *actually* type parameters, but look like one; this gives mappers the opportunity
// to handle one directly (as is done for partial inference), before it gets mapped to its target.
const result = mapper(<InferType>type);
if (result !== type) {
return result;
}
return instantiateType((<InferType>type).target, mapper);
}
}
return type;
}
Expand Down Expand Up @@ -9642,9 +9664,15 @@ namespace ts {
if (source.flags & TypeFlags.Substitution) {
source = relation === definitelyAssignableRelation ? (<SubstitutionType>source).typeParameter : (<SubstitutionType>source).substitute;
}
if (source.flags & TypeFlags.InferType) {
source = (<InferType>source).target;
}
if (target.flags & TypeFlags.Substitution) {
target = (<SubstitutionType>target).typeParameter;
}
if (target.flags & TypeFlags.InferType) {
target = (<InferType>target).target;
}

// both types are the same - covers 'they are the same primitive type or both are Any' or the same type parameter cases
if (source === target) return Ternary.True;
Expand Down Expand Up @@ -11587,6 +11615,12 @@ namespace ts {
if (!couldContainTypeVariables(target)) {
return;
}
if (source.flags & TypeFlags.InferType) {
source = (source as InferType).target;
}
if (target.flags & TypeFlags.InferType) {
target = (target as InferType).target;
}
if (source.flags & TypeFlags.Any) {
// We are inferring from an 'any' type. We want to infer this type for every type parameter
// referenced in the target type, so we record it as the propagation type and infer from the
Expand Down Expand Up @@ -17529,10 +17563,35 @@ namespace ts {
candidate = originalCandidate;
if (candidate.typeParameters) {
let typeArgumentTypes: Type[];
const isJavascript = isInJavaScriptFile(candidate.declaration);
if (typeArguments) {
const typeArgumentResult = checkTypeArguments(candidate, typeArguments, /*reportErrors*/ false);
if (typeArgumentResult) {
typeArgumentTypes = typeArgumentResult;
if (node.locals) {
// Call has `infer` arguments that still need to be inferred and instantiated
const inferParams = getInferTypeParameters(node);
// Mapper replaces references to infered type parameters with emptyObjectType
// Causing the original location to be the _only_ inference site
const preprocessMapper = (p: TypeParameter) => {
// Fresh infer types are not *actually* type parameters, but look like one
if (p.flags & TypeFlags.InferType) {
return p.target; // By doing the replacement here, we cause this mapper to be not-called with the target
}
if (contains(inferParams, p)) {
return emptyObjectType;
}
return p;
};
const resultsWithNonInferInferredVarsDefaulted = map(typeArgumentResult, t => instantiateType(t, preprocessMapper));
const partialCandidate = getSignatureInstantiation(candidate, resultsWithNonInferInferredVarsDefaulted, isJavascript);
const context = createInferenceContext(inferParams, partialCandidate, InferenceFlags.None);
const inferences = inferTypeArguments(node, partialCandidate, args, excludeArgument, context);
const mapper = createTypeMapper(inferParams, inferences);
typeArgumentTypes = map(typeArgumentResult, t => instantiateType(t, mapper));
}
else {
typeArgumentTypes = typeArgumentResult;
}
}
else {
candidateForTypeArgumentError = originalCandidate;
Expand All @@ -17542,7 +17601,6 @@ namespace ts {
else {
typeArgumentTypes = inferTypeArguments(node, candidate, args, excludeArgument, inferenceContext);
}
const isJavascript = isInJavaScriptFile(candidate.declaration);
candidate = getSignatureInstantiation(candidate, typeArgumentTypes, isJavascript);
}
if (!checkApplicableSignature(node, args, candidate, relation, excludeArgument, /*reportErrors*/ false)) {
Expand Down Expand Up @@ -20544,9 +20602,18 @@ namespace ts {
forEachChild(node, checkSourceElement);
}

function isConditionalTypeExtendsClause(n: Node) {
return n.parent && n.parent.kind === SyntaxKind.ConditionalType && (<ConditionalTypeNode>n.parent).extendsType === n;
}

function isCallOrNewExpressionTypeArgument(n: Node) {
return n.parent && (n.parent.kind === SyntaxKind.CallExpression || n.parent.kind === SyntaxKind.NewExpression)
&& contains((<CallExpression | NewExpression>n.parent).typeArguments, n);
}

function checkInferType(node: InferTypeNode) {
if (!findAncestor(node, n => n.parent && n.parent.kind === SyntaxKind.ConditionalType && (<ConditionalTypeNode>n.parent).extendsType === n)) {
grammarErrorOnNode(node, Diagnostics.infer_declarations_are_only_permitted_in_the_extends_clause_of_a_conditional_type);
if (!findAncestor(node, n => isConditionalTypeExtendsClause(n) || isCallOrNewExpressionTypeArgument(n))) {
grammarErrorOnNode(node, Diagnostics.infer_declarations_are_only_permitted_in_the_extends_clause_of_a_conditional_type_or_in_call_or_new_expression_type_argument_lists);
}
checkSourceElement(node.typeParameter);
}
Expand Down
11 changes: 11 additions & 0 deletions src/compiler/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,17 @@ namespace ts {
array.length = 0;
}

export function mapIndexless<T, U>(array: ReadonlyArray<T>, f: (x: T) => U): U[] {
let result: U[];
if (array) {
result = [];
for (const elem of array) {
result.push(f(elem));
}
}
return result;
}

export function map<T, U>(array: ReadonlyArray<T>, f: (x: T, i: number) => U): U[] {
let result: U[];
if (array) {
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/diagnosticMessages.json
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@
"category": "Error",
"code": 1337
},
"'infer' declarations are only permitted in the 'extends' clause of a conditional type.": {
"'infer' declarations are only permitted in the 'extends' clause of a conditional type or in call or new expression type argument lists.": {
"category": "Error",
"code": 1338
},
Expand Down
8 changes: 7 additions & 1 deletion src/compiler/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3538,6 +3538,7 @@ namespace ts {
/* @internal */
ContainsAnyFunctionType = 1 << 26, // Type is or contains the anyFunctionType
NonPrimitive = 1 << 27, // intrinsic object type
InferType = 1 << 28, // A type whose concrete value upon instantiation will be inferred at a given site
/* @internal */
GenericMappedType = 1 << 29, // Flag used by maybeTypeOfKind

Expand All @@ -3562,7 +3563,7 @@ namespace ts {
ESSymbolLike = ESSymbol | UniqueESSymbol,
UnionOrIntersection = Union | Intersection,
StructuredType = Object | Union | Intersection,
TypeVariable = TypeParameter | IndexedAccess,
TypeVariable = TypeParameter | IndexedAccess | InferType,
InstantiableNonPrimitive = TypeVariable | Conditional | Substitution,
InstantiablePrimitive = Index,
Instantiable = InstantiableNonPrimitive | InstantiablePrimitive,
Expand Down Expand Up @@ -3817,6 +3818,11 @@ namespace ts {
resolvedDefaultType?: Type;
}

// Infer Types (TypeFlags.InferType)
export interface InferType extends Type {
target: TypeParameter;
}

// Indexed access types (TypeFlags.IndexedAccess)
// Possible forms are T[xxx], xxx[T], or xxx[keyof T], where T is a type variable
export interface IndexedAccessType extends InstantiableType {
Expand Down
14 changes: 9 additions & 5 deletions tests/baselines/reference/api/tsserverlibrary.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,7 @@ declare namespace ts {
Conditional = 2097152,
Substitution = 4194304,
NonPrimitive = 134217728,
InferType = 268435456,
Literal = 224,
Unit = 13536,
StringOrNumberLiteral = 96,
Expand All @@ -2085,12 +2086,12 @@ declare namespace ts {
ESSymbolLike = 1536,
UnionOrIntersection = 393216,
StructuredType = 458752,
TypeVariable = 1081344,
InstantiableNonPrimitive = 7372800,
TypeVariable = 269516800,
InstantiableNonPrimitive = 275808256,
InstantiablePrimitive = 524288,
Instantiable = 7897088,
StructuredOrInstantiable = 8355840,
Narrowable = 142575359,
Instantiable = 276332544,
StructuredOrInstantiable = 276791296,
Narrowable = 411010815,
NotUnionOrUnit = 134283777,
}
type DestructuringPattern = BindingPattern | ObjectLiteralExpression | ArrayLiteralExpression;
Expand Down Expand Up @@ -2184,6 +2185,9 @@ declare namespace ts {
}
interface TypeParameter extends InstantiableType {
}
interface InferType extends Type {
target: TypeParameter;
}
interface IndexedAccessType extends InstantiableType {
objectType: Type;
indexType: Type;
Expand Down
14 changes: 9 additions & 5 deletions tests/baselines/reference/api/typescript.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,7 @@ declare namespace ts {
Conditional = 2097152,
Substitution = 4194304,
NonPrimitive = 134217728,
InferType = 268435456,
Literal = 224,
Unit = 13536,
StringOrNumberLiteral = 96,
Expand All @@ -2085,12 +2086,12 @@ declare namespace ts {
ESSymbolLike = 1536,
UnionOrIntersection = 393216,
StructuredType = 458752,
TypeVariable = 1081344,
InstantiableNonPrimitive = 7372800,
TypeVariable = 269516800,
InstantiableNonPrimitive = 275808256,
InstantiablePrimitive = 524288,
Instantiable = 7897088,
StructuredOrInstantiable = 8355840,
Narrowable = 142575359,
Instantiable = 276332544,
StructuredOrInstantiable = 276791296,
Narrowable = 411010815,
NotUnionOrUnit = 134283777,
}
type DestructuringPattern = BindingPattern | ObjectLiteralExpression | ArrayLiteralExpression;
Expand Down Expand Up @@ -2184,6 +2185,9 @@ declare namespace ts {
}
interface TypeParameter extends InstantiableType {
}
interface InferType extends Type {
target: TypeParameter;
}
interface IndexedAccessType extends InstantiableType {
objectType: Type;
indexType: Type;
Expand Down
Loading