-
Notifications
You must be signed in to change notification settings - Fork 4.1k
/
Copy pathCSharpDeclareAsNullableCodeFixProvider.cs
344 lines (295 loc) · 15.3 KB
/
CSharpDeclareAsNullableCodeFixProvider.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.DeclareAsNullable;
[ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.DeclareAsNullable), Shared]
internal class CSharpDeclareAsNullableCodeFixProvider : SyntaxEditorBasedCodeFixProvider
{
    // We want to distinguish different situations:
    // 1. local null assignments: `return null;`, `local = null;`, `parameter = null;` (high confidence that the null is introduced deliberately and the API should be updated)
    // 2. invocation with null: `M(null);`, or assigning null to field or property (test code might do this even though the API should remain not-nullable, so FixAll should be invoked with care)
    // 3. conditional: `return x?.ToString();`
    private const string AssigningNullLiteralLocallyEquivalenceKey = nameof(AssigningNullLiteralLocallyEquivalenceKey);
    private const string AssigningNullLiteralRemotelyEquivalenceKey = nameof(AssigningNullLiteralRemotelyEquivalenceKey);
    private const string ConditionalOperatorEquivalenceKey = nameof(ConditionalOperatorEquivalenceKey);

    [ImportingConstructor]
    [SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
    public CSharpDeclareAsNullableCodeFixProvider()
    {
    }

    // warning CS8603: Possible null reference return.
    // warning CS8600: Converting null literal or possible null value to non-nullable type.
    // warning CS8625: Cannot convert null literal to non-nullable reference type.
    // warning CS8618: Non-nullable property is uninitialized
    public sealed override ImmutableArray<string> FixableDiagnosticIds => ["CS8603", "CS8600", "CS8625", "CS8618"];

    /// <summary>
    /// Registers "Declare as nullable" for the first diagnostic in <paramref name="context"/>, but only when
    /// that diagnostic can be traced back to a single type syntax that can be annotated with <c>?</c>.
    /// </summary>
    public override async Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        var cancellationToken = context.CancellationToken;
        var model = await context.Document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var node = context.Diagnostics.First().Location.FindNode(getInnermostNodeForTie: true, cancellationToken);

        // Nothing to offer if there is no declaration we could rewrite.
        if (TryGetDeclarationTypeToFix(model, node, cancellationToken) is null)
            return;

        RegisterCodeFix(context, CSharpCodeFixesResources.Declare_as_nullable, GetEquivalenceKey(node, model, cancellationToken));
    }

    /// <summary>
    /// Classifies <paramref name="node"/> into one of the three situations listed at the top of this class,
    /// so that Fix All only batches fixes of the same kind together.
    /// </summary>
    private static string GetEquivalenceKey(SyntaxNode node, SemanticModel model, CancellationToken cancellationToken)
    {
        if (IsRemoteApiUsage(node, model, cancellationToken))
            return AssigningNullLiteralRemotelyEquivalenceKey;

        return node.IsKind(SyntaxKind.ConditionalAccessExpression)
            ? ConditionalOperatorEquivalenceKey
            : AssigningNullLiteralLocallyEquivalenceKey;

        static bool IsRemoteApiUsage(SyntaxNode node, SemanticModel model, CancellationToken cancellationToken)
        {
            // M(null) could be used in a test
            if (node.IsParentKind(SyntaxKind.Argument))
                return true;

            // x.field = null; or x.Property = null; could be used in a test
            return node.Parent is AssignmentExpressionSyntax assignment &&
                model.GetSymbolInfo(assignment.Left, cancellationToken).Symbol is IFieldSymbol or IPropertySymbol;
        }
    }

    /// <summary>
    /// Annotates the declaration associated with each diagnostic in <paramref name="diagnostics"/>.
    /// </summary>
    protected override async Task FixAllAsync(
        Document document,
        ImmutableArray<Diagnostic> diagnostics,
        SyntaxEditor editor,
        CodeActionOptionsProvider fallbackOptions,
        CancellationToken cancellationToken)
    {
        // a method can have multiple `return null;` statements, but we should only fix its return type once
        using var _ = PooledHashSet<TypeSyntax>.GetInstance(out var alreadyHandled);
        var model = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        foreach (var diagnostic in diagnostics)
        {
            var node = diagnostic.Location.FindNode(getInnermostNodeForTie: true, cancellationToken);
            MakeDeclarationNullable(editor, model, node, alreadyHandled, cancellationToken);
        }
    }

    /// <summary>
    /// Includes a diagnostic in Fix All only when it falls into the same situation as the fix that was invoked.
    /// </summary>
    protected override bool IncludeDiagnosticDuringFixAll(Diagnostic diagnostic, Document document, SemanticModel model, string? equivalenceKey, CancellationToken cancellationToken)
    {
        var node = diagnostic.Location.FindNode(getInnermostNodeForTie: true, cancellationToken);
        return equivalenceKey == GetEquivalenceKey(node, model, cancellationToken);
    }

    /// <summary>
    /// Replaces the declaration type for <paramref name="node"/> with <c>T?</c>, keeping the original trivia.
    /// <paramref name="alreadyHandled"/> ensures each type syntax is rewritten at most once.
    /// </summary>
    private static void MakeDeclarationNullable(
        SyntaxEditor editor, SemanticModel model, SyntaxNode node, HashSet<TypeSyntax> alreadyHandled, CancellationToken cancellationToken)
    {
        var declarationTypeToFix = TryGetDeclarationTypeToFix(model, node, cancellationToken);
        if (declarationTypeToFix != null && alreadyHandled.Add(declarationTypeToFix))
        {
            var fixedDeclaration = SyntaxFactory.NullableType(declarationTypeToFix.WithoutTrivia()).WithTriviaFrom(declarationTypeToFix);
            editor.ReplaceNode(declarationTypeToFix, fixedDeclaration);
        }
    }

    /// <summary>
    /// Maps the node a diagnostic was reported on to the type syntax whose annotation would address it.
    /// Returns <see langword="null"/> when there is no single, safely editable declaration (for example,
    /// multi-variable declarations, partial methods, or unsupported containing members).
    /// </summary>
    private static TypeSyntax? TryGetDeclarationTypeToFix(
        SemanticModel model, SyntaxNode node, CancellationToken cancellationToken)
    {
        if (!IsExpressionSupported(node))
            return null;

        if (node.Parent is (kind: SyntaxKind.ReturnStatement or SyntaxKind.YieldReturnStatement))
        {
            // The nearest function-like ancestor owns the return; lambdas and other members are listed so the
            // search stops there instead of walking out to an enclosing method.
            var containingMember = node.GetAncestors().FirstOrDefault(
                a => a.Kind() is
                    SyntaxKind.MethodDeclaration or
                    SyntaxKind.PropertyDeclaration or
                    SyntaxKind.ParenthesizedLambdaExpression or
                    SyntaxKind.SimpleLambdaExpression or
                    SyntaxKind.LocalFunctionStatement or
                    SyntaxKind.AnonymousMethodExpression or
                    SyntaxKind.ConstructorDeclaration or
                    SyntaxKind.DestructorDeclaration or
                    SyntaxKind.OperatorDeclaration or
                    SyntaxKind.IndexerDeclaration or
                    SyntaxKind.EventDeclaration);

            if (containingMember == null)
                return null;

            var onYield = node.IsParentKind(SyntaxKind.YieldReturnStatement);

            return containingMember switch
            {
                MethodDeclarationSyntax method =>
                    // string M() { return null; }
                    // async Task<string> M() { return null; }
                    // IEnumerable<string> M() { yield return null; }
                    TryGetReturnType(method.ReturnType, method.Modifiers, onYield),

                LocalFunctionStatementSyntax localFunction =>
                    // string local() { return null; }
                    // async Task<string> local() { return null; }
                    // IEnumerable<string> local() { yield return null; }
                    TryGetReturnType(localFunction.ReturnType, localFunction.Modifiers, onYield),

                PropertyDeclarationSyntax property =>
                    // string x { get { return null; } }
                    // IEnumerable<string> Property { get { yield return null; } }
                    TryGetReturnType(property.Type, modifiers: default, onYield),

                _ => null,
            };
        }

        // string x = null;
        if (node.Parent?.Parent?.Parent is VariableDeclarationSyntax variableDeclaration)
        {
            // string x = null, y = null;
            return variableDeclaration.Variables.Count == 1 ? variableDeclaration.Type : null;
        }

        // x = null;
        if (node.Parent is AssignmentExpressionSyntax assignment)
        {
            var symbol = model.GetSymbolInfo(assignment.Left, cancellationToken).Symbol;
            if (symbol is ILocalSymbol { DeclaringSyntaxReferences.Length: > 0 } local)
            {
                var syntax = local.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken);
                if (syntax is VariableDeclaratorSyntax { Parent: VariableDeclarationSyntax { Variables.Count: 1 } declaration })
                    return declaration.Type;
            }
            else if (symbol is IParameterSymbol parameter)
            {
                return TryGetParameterTypeSyntax(parameter, cancellationToken);
            }
            else if (symbol is IFieldSymbol { IsImplicitlyDeclared: false, DeclaringSyntaxReferences.Length: > 0 } field)
            {
                // implicitly declared fields don't have DeclaringSyntaxReferences so filter them out
                var syntax = field.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken);
                if (syntax is VariableDeclaratorSyntax { Parent: VariableDeclarationSyntax { Variables.Count: 1 } declaration })
                    return declaration.Type;

                if (syntax is TupleElementSyntax tupleElement)
                    return tupleElement.Type;
            }
            else if (symbol is IFieldSymbol { CorrespondingTupleField: IFieldSymbol { Locations: [{ IsInSource: true } location] } })
            {
                // Assigning a tuple field, eg. foo.Item1 = null
                // The tupleField won't have DeclaringSyntaxReferences because it's implicitly declared, otherwise it
                // would have fallen into the branch above. We can use the Locations instead, if there is one and it's in source
                if (location.FindNode(cancellationToken) is TupleElementSyntax tupleElement)
                    return tupleElement.Type;
            }
            else if (symbol is IPropertySymbol { DeclaringSyntaxReferences.Length: > 0 } property)
            {
                var syntax = property.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken);
                if (syntax is PropertyDeclarationSyntax declaration)
                    return declaration.Type;
            }

            return null;
        }

        // Method(null)
        if (node.Parent is ArgumentSyntax argument && argument.Parent?.Parent is InvocationExpressionSyntax invocation)
        {
            var symbol = model.GetSymbolInfo(invocation.Expression, cancellationToken).Symbol;
            if (symbol is not IMethodSymbol method || method.PartialImplementationPart is not null)
            {
                // https://github.com/dotnet/roslyn/issues/73772: should we also bail out on a partial property?
                // We don't handle partial methods yet
                return null;
            }

            if (argument.NameColon?.Name is IdentifierNameSyntax { Identifier: var identifier })
            {
                var parameter = method.Parameters.FirstOrDefault(p => p.Name == identifier.Text);
                return TryGetParameterTypeSyntax(parameter, cancellationToken);
            }

            var index = invocation.ArgumentList.Arguments.IndexOf(argument);
            if (index >= 0 && index < method.Parameters.Length)
            {
                var parameter = method.Parameters[index];
                return TryGetParameterTypeSyntax(parameter, cancellationToken);
            }

            return null;
        }

        // string x { get; set; } = null;
        // (also covers expression-bodied properties: string x => null;)
        if (node.Parent?.Parent is PropertyDeclarationSyntax propertyDeclaration)
            return propertyDeclaration.Type;

        // string x { get; }
        // Unassigned value that's not marked as null
        if (node is PropertyDeclarationSyntax propertyDeclarationSyntax)
            return propertyDeclarationSyntax.Type;

        // string x;
        // Unassigned value that's not marked as null
        if (node is VariableDeclaratorSyntax { Parent: VariableDeclarationSyntax { Parent: FieldDeclarationSyntax, Variables.Count: 1 } declarationSyntax })
            return declarationSyntax.Type;

        // void M(string x = null) { }
        if (node.Parent?.Parent is ParameterSyntax optionalParameter)
        {
            var parameterSymbol = model.GetDeclaredSymbol(optionalParameter, cancellationToken);
            return TryGetParameterTypeSyntax(parameterSymbol, cancellationToken);
        }

        // static string M() => null;
        // async Task<string> M() => null;
        // string local() => null;
        // Route through TryGetReturnType so that, for async members, the awaited type argument is annotated
        // rather than the Task itself (which would produce `Task<string>?` and leave the warning in place).
        if (node.Parent is ArrowExpressionClauseSyntax { Parent: var arrowOwner })
        {
            switch (arrowOwner)
            {
                case MethodDeclarationSyntax arrowMethod:
                    return TryGetReturnType(arrowMethod.ReturnType, arrowMethod.Modifiers, onYield: false);
                case LocalFunctionStatementSyntax arrowLocalFunction:
                    return TryGetReturnType(arrowLocalFunction.ReturnType, arrowLocalFunction.Modifiers, onYield: false);
            }
        }

        return null;

        // local functions

        // Picks the type to annotate for a returned value: the type itself for ordinary members, or its single
        // type argument for async members and iterators.
        static TypeSyntax? TryGetReturnType(TypeSyntax returnType, SyntaxTokenList modifiers, bool onYield)
        {
            if (modifiers.Any(SyntaxKind.AsyncKeyword) || onYield)
            {
                // async Task<string> M() { return null; }
                // async IAsyncEnumerable<string> M() { yield return null; }
                // IEnumerable<string> M() { yield return null; }
                return TryGetSingleTypeArgument(returnType);
            }

            // string M() { return null; }
            return returnType;
        }

        // Returns the only type argument of a (possibly qualified) generic name, or null if there isn't exactly one.
        static TypeSyntax? TryGetSingleTypeArgument(TypeSyntax type)
        {
            switch (type)
            {
                case QualifiedNameSyntax qualified:
                    // System.Threading.Tasks.Task<string>
                    return TryGetSingleTypeArgument(qualified.Right);

                case AliasQualifiedNameSyntax aliasQualified:
                    // global::Task<string>
                    return TryGetSingleTypeArgument(aliasQualified.Name);

                case GenericNameSyntax generic:
                    var typeArguments = generic.TypeArgumentList.Arguments;
                    if (typeArguments.Count == 1)
                        return typeArguments[0];

                    break;
            }

            return null;
        }

        // Returns the declared type of a parameter in source, unless its method has more than one partial part.
        static TypeSyntax? TryGetParameterTypeSyntax(IParameterSymbol? parameterSymbol, CancellationToken cancellationToken)
        {
            if (parameterSymbol?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(cancellationToken) is ParameterSyntax parameterSyntax &&
                parameterSymbol.ContainingSymbol is IMethodSymbol method &&
                method.GetAllMethodSymbolsOfPartialParts().Length == 1)
            {
                return parameterSyntax.Type;
            }

            return null;
        }
    }

    /// <summary>
    /// The node kinds a fixable diagnostic may be reported on.
    /// </summary>
    private static bool IsExpressionSupported(SyntaxNode node)
        => node.Kind() is
            SyntaxKind.NullLiteralExpression or
            SyntaxKind.AsExpression or
            SyntaxKind.DefaultExpression or
            SyntaxKind.DefaultLiteralExpression or
            SyntaxKind.ConditionalExpression or
            SyntaxKind.ConditionalAccessExpression or
            SyntaxKind.PropertyDeclaration or
            SyntaxKind.VariableDeclarator;
}