Skip to content

Commit c80f459

Browse files
authored
Don't capture async locals in resolver (#2426)
1 parent 63914f2 commit c80f459

File tree

4 files changed

+81
-9
lines changed

4 files changed

+81
-9
lines changed

src/Grpc.Net.Client/Balancer/PollingResolver.cs

+26-7
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,33 @@ public sealed override void Refresh()
135135

136136
if (_resolveTask.IsCompleted)
137137
{
138-
// Run ResolveAsync in a background task.
139-
// This is done to prevent synchronous block inside ResolveAsync from blocking future Refresh calls.
140-
_resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token);
141-
_resolveTask.ContinueWith(static (t, state) =>
138+
// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
139+
var restoreFlow = false;
140+
try
141+
{
142+
if (!ExecutionContext.IsFlowSuppressed())
143+
{
144+
ExecutionContext.SuppressFlow();
145+
restoreFlow = true;
146+
}
147+
148+
// Run ResolveAsync in a background task.
149+
// This is done to prevent synchronous block inside ResolveAsync from blocking future Refresh calls.
150+
_resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token);
151+
_resolveTask.ContinueWith(static (t, state) =>
152+
{
153+
var pollingResolver = (PollingResolver)state!;
154+
Log.ResolveTaskCompleted(pollingResolver._logger, pollingResolver.GetType());
155+
}, this);
156+
}
157+
finally
142158
{
143-
var pollingResolver = (PollingResolver)state!;
144-
Log.ResolveTaskCompleted(pollingResolver._logger, pollingResolver.GetType());
145-
}, this);
159+
// Restore the current ExecutionContext
160+
if (restoreFlow)
161+
{
162+
ExecutionContext.RestoreFlow();
163+
}
164+
}
146165
}
147166
else
148167
{

src/Grpc.Net.Client/Balancer/Subchannel.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ public void RequestConnection()
257257
}
258258

259259
// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
260-
bool restoreFlow = false;
260+
var restoreFlow = false;
261261
if (!ExecutionContext.IsFlowSuppressed())
262262
{
263263
ExecutionContext.SuppressFlow();

src/Shared/NonCapturingTimer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public static Timer Create(TimerCallback callback, object? state, TimeSpan dueTi
1313
ArgumentNullThrowHelper.ThrowIfNull(callback);
1414

1515
// Don't capture the current ExecutionContext and its AsyncLocals onto the timer
16-
bool restoreFlow = false;
16+
var restoreFlow = false;
1717
try
1818
{
1919
if (!ExecutionContext.IsFlowSuppressed())

test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs

+53
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,59 @@ protected override Task ResolveAsync(CancellationToken cancellationToken)
111111
}
112112
}
113113

114+
[Test]
115+
public async Task Refresh_AsyncLocal_NotCaptured()
116+
{
117+
// Arrange
118+
var services = new ServiceCollection();
119+
services.AddNUnitLogger();
120+
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();
121+
122+
var asyncLocal = new AsyncLocal<object>();
123+
asyncLocal.Value = new object();
124+
125+
var callbackAsyncLocalValues = new List<object>();
126+
127+
var resolver = new CallbackPollingResolver(loggerFactory, new TestBackoffPolicyFactory(TimeSpan.FromMilliseconds(100)), (listener) =>
128+
{
129+
callbackAsyncLocalValues.Add(asyncLocal.Value);
130+
if (callbackAsyncLocalValues.Count >= 2)
131+
{
132+
listener(ResolverResult.ForResult(new List<BalancerAddress>()));
133+
}
134+
135+
return Task.CompletedTask;
136+
});
137+
138+
var tcs = new TaskCompletionSource<ResolverResult>(TaskCreationOptions.RunContinuationsAsynchronously);
139+
resolver.Start(result => tcs.TrySetResult(result));
140+
141+
// Act
142+
resolver.Refresh();
143+
144+
// Assert
145+
await tcs.Task.DefaultTimeout();
146+
147+
Assert.AreEqual(2, callbackAsyncLocalValues.Count);
148+
Assert.IsNull(callbackAsyncLocalValues[0]);
149+
Assert.IsNull(callbackAsyncLocalValues[1]);
150+
}
151+
152+
private class CallbackPollingResolver : PollingResolver
153+
{
154+
private readonly Func<Action<ResolverResult>, Task> _callback;
155+
156+
public CallbackPollingResolver(ILoggerFactory loggerFactory, IBackoffPolicyFactory backoffPolicyFactory, Func<Action<ResolverResult>, Task> callback) : base(loggerFactory, backoffPolicyFactory)
157+
{
158+
_callback = callback;
159+
}
160+
161+
protected override Task ResolveAsync(CancellationToken cancellationToken)
162+
{
163+
return _callback(Listener);
164+
}
165+
}
166+
114167
[Test]
115168
public async Task Resolver_ResolveNameFromServices_Success()
116169
{

0 commit comments

Comments
 (0)