Skip to content

Commit 63914f2

Browse files
authored
Add semaphore to limit subchannel connect to prevent race conditions (#2422)
1 parent 8199f66 commit 63914f2

File tree

3 files changed

+126
-27
lines changed

3 files changed

+126
-27
lines changed

src/Grpc.Net.Client/Balancer/Subchannel.cs

+83-26
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public sealed class Subchannel : IDisposable
5454

5555
internal readonly ConnectionManager _manager;
5656
private readonly ILogger _logger;
57+
private readonly SemaphoreSlim _connectSemaphore;
5758

5859
private ISubchannelTransport _transport = default!;
5960
private ConnectContext? _connectContext;
@@ -89,6 +90,7 @@ internal Subchannel(ConnectionManager manager, IReadOnlyList<BalancerAddress> ad
8990
{
9091
Lock = new object();
9192
_logger = manager.LoggerFactory.CreateLogger(GetType());
93+
_connectSemaphore = new SemaphoreSlim(1);
9294

9395
Id = manager.GetNextId();
9496
_addresses = addresses.ToList();
@@ -213,7 +215,10 @@ public void UpdateAddresses(IReadOnlyList<BalancerAddress> addresses)
213215

214216
if (requireReconnect)
215217
{
216-
CancelInProgressConnect();
218+
lock (Lock)
219+
{
220+
CancelInProgressConnectUnsynchronized();
221+
}
217222
_transport.Disconnect();
218223
RequestConnection();
219224
}
@@ -268,43 +273,76 @@ public void RequestConnection()
268273
}
269274
}
270275

271-
private void CancelInProgressConnect()
276+
private void CancelInProgressConnectUnsynchronized()
272277
{
273-
lock (Lock)
274-
{
275-
if (_connectContext != null && !_connectContext.Disposed)
276-
{
277-
SubchannelLog.CancelingConnect(_logger, Id);
278+
Debug.Assert(Monitor.IsEntered(Lock));
278279

279-
// Cancel connect cancellation token.
280-
_connectContext.CancelConnect();
281-
_connectContext.Dispose();
282-
}
280+
if (_connectContext != null && !_connectContext.Disposed)
281+
{
282+
SubchannelLog.CancelingConnect(_logger, Id);
283283

284-
_delayInterruptTcs?.TrySetResult(null);
284+
// Cancel connect cancellation token.
285+
_connectContext.CancelConnect();
286+
_connectContext.Dispose();
285287
}
288+
289+
_delayInterruptTcs?.TrySetResult(null);
286290
}
287291

288-
private ConnectContext GetConnectContext()
292+
private ConnectContext GetConnectContextUnsynchronized()
289293
{
290-
lock (Lock)
291-
{
292-
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
293-
CancelInProgressConnect();
294+
Debug.Assert(Monitor.IsEntered(Lock));
294295

295-
var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
296-
return connectContext;
297-
}
296+
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
297+
CancelInProgressConnectUnsynchronized();
298+
299+
var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
300+
return connectContext;
298301
}
299302

300303
private async Task ConnectTransportAsync()
301304
{
302-
var connectContext = GetConnectContext();
305+
ConnectContext connectContext;
306+
Task? waitSemaporeTask = null;
307+
lock (Lock)
308+
{
309+
// Don't start connecting if the subchannel has been shutdown. Transport/semaphore will be disposed if shutdown.
310+
if (_state == ConnectivityState.Shutdown)
311+
{
312+
return;
313+
}
314+
315+
connectContext = GetConnectContextUnsynchronized();
316+
317+
// Use a semaphore to limit one connection attempt at a time. This is done to prevent a race conditional where a canceled connect
318+
// overwrites the status of a successful connect.
319+
//
320+
// Try to get semaphore without waiting. If semaphore is already taken then start a task to wait for it to be released.
321+
// Start this inside a lock to make sure subchannel isn't shutdown before waiting for semaphore.
322+
if (!_connectSemaphore.Wait(0))
323+
{
324+
SubchannelLog.QueuingConnect(_logger, Id);
325+
waitSemaporeTask = _connectSemaphore.WaitAsync(connectContext.CancellationToken);
326+
}
327+
}
303328

304-
var backoffPolicy = _manager.BackoffPolicyFactory.Create();
329+
if (waitSemaporeTask != null)
330+
{
331+
try
332+
{
333+
await waitSemaporeTask.ConfigureAwait(false);
334+
}
335+
catch (OperationCanceledException)
336+
{
337+
// Canceled while waiting for semaphore.
338+
return;
339+
}
340+
}
305341

306342
try
307343
{
344+
var backoffPolicy = _manager.BackoffPolicyFactory.Create();
345+
308346
SubchannelLog.ConnectingTransport(_logger, Id);
309347

310348
for (var attempt = 0; ; attempt++)
@@ -384,6 +422,13 @@ private async Task ConnectTransportAsync()
384422
// Dispose context because it might have been created with a connect timeout.
385423
// Want to clean up the connect timeout timer.
386424
connectContext.Dispose();
425+
426+
// Subchannel could have been disposed while connect is running.
427+
// If subchannel is shutting down then don't release semaphore to avoid ObjectDisposedException.
428+
if (_state != ConnectivityState.Shutdown)
429+
{
430+
_connectSemaphore.Release();
431+
}
387432
}
388433
}
389434
}
@@ -482,8 +527,12 @@ public void Dispose()
482527
}
483528
_stateChangedRegistrations.Clear();
484529

485-
CancelInProgressConnect();
486-
_transport.Dispose();
530+
lock (Lock)
531+
{
532+
CancelInProgressConnectUnsynchronized();
533+
_transport.Dispose();
534+
_connectSemaphore.Dispose();
535+
}
487536
}
488537
}
489538

@@ -505,7 +554,7 @@ internal static class SubchannelLog
505554
LoggerMessage.Define<string, ConnectivityState>(LogLevel.Debug, new EventId(5, "ConnectionRequestedInNonIdleState"), "Subchannel id '{SubchannelId}' connection requested in non-idle state of {State}.");
506555

507556
private static readonly Action<ILogger, string, Exception?> _connectingTransport =
508-
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(6, "ConnectingTransport"), "Subchannel id '{SubchannelId}' connecting to transport.");
557+
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(6, "ConnectingTransport"), "Subchannel id '{SubchannelId}' connecting to transport.");
509558

510559
private static readonly Action<ILogger, string, TimeSpan, Exception?> _startingConnectBackoff =
511560
LoggerMessage.Define<string, TimeSpan>(LogLevel.Trace, new EventId(7, "StartingConnectBackoff"), "Subchannel id '{SubchannelId}' starting connect backoff of {BackoffDuration}.");
@@ -514,7 +563,7 @@ internal static class SubchannelLog
514563
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(8, "ConnectBackoffInterrupted"), "Subchannel id '{SubchannelId}' connect backoff interrupted.");
515564

516565
private static readonly Action<ILogger, string, Exception?> _connectCanceled =
517-
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(9, "ConnectCanceled"), "Subchannel id '{SubchannelId}' connect canceled.");
566+
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(9, "ConnectCanceled"), "Subchannel id '{SubchannelId}' connect canceled.");
518567

519568
private static readonly Action<ILogger, string, Exception?> _connectError =
520569
LoggerMessage.Define<string>(LogLevel.Error, new EventId(10, "ConnectError"), "Subchannel id '{SubchannelId}' unexpected error while connecting to transport.");
@@ -546,6 +595,9 @@ internal static class SubchannelLog
546595
private static readonly Action<ILogger, string, string, Exception?> _addressesUpdated =
547596
LoggerMessage.Define<string, string>(LogLevel.Trace, new EventId(19, "AddressesUpdated"), "Subchannel id '{SubchannelId}' updated with addresses: {Addresses}");
548597

598+
private static readonly Action<ILogger, string, Exception?> _queuingConnect =
599+
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(20, "QueuingConnect"), "Subchannel id '{SubchannelId}' queuing connect because a connect is already in progress.");
600+
549601
public static void SubchannelCreated(ILogger logger, string subchannelId, IReadOnlyList<BalancerAddress> addresses)
550602
{
551603
if (logger.IsEnabled(LogLevel.Debug))
@@ -648,5 +700,10 @@ public static void AddressesUpdated(ILogger logger, string subchannelId, IReadOn
648700
_addressesUpdated(logger, subchannelId, addressesText, null);
649701
}
650702
}
703+
704+
public static void QueuingConnect(ILogger logger, string subchannelId)
705+
{
706+
_queuingConnect(logger, subchannelId, null);
707+
}
651708
}
652709
#endif

test/FunctionalTests/Balancer/ConnectionTests.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,11 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
9393

9494
var client = TestClientFactory.Create(channel, endpoint.Method);
9595

96+
// Act
9697
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => client.UnaryCall(new HelloRequest()).ResponseAsync).DefaultTimeout();
9798
Assert.AreEqual("A connection could not be established within the configured ConnectTimeout.", ex.Status.DebugException!.Message);
99+
100+
await ExceptionAssert.ThrowsAsync<OperationCanceledException>(() => connectTcs.Task).DefaultTimeout();
98101
}
99102

100103
[Test]
@@ -167,7 +170,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
167170
connectionIdleTimeout: connectionIdleTimeout).DefaultTimeout();
168171

169172
Logger.LogInformation("Connecting channel.");
170-
await channel.ConnectAsync();
173+
await channel.ConnectAsync().DefaultTimeout();
171174

172175
// Wait for timeout plus a little extra to avoid issues from imprecise timers.
173176
await Task.Delay(connectionIdleTimeout + TimeSpan.FromMilliseconds(50));

test/FunctionalTests/Balancer/PickFirstBalancerTests.cs

+39
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,45 @@ private GrpcChannel CreateGrpcWebChannel(TestServerEndpointName endpointName, Ve
6161
return channel;
6262
}
6363

64+
[Test]
65+
public async Task UnaryCall_CallAfterConnectionTimeout_Success()
66+
{
67+
// Ignore errors
68+
SetExpectedErrorsFilter(writeContext =>
69+
{
70+
return true;
71+
});
72+
73+
string? host = null;
74+
Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
75+
{
76+
host = context.Host;
77+
return Task.FromResult(new HelloReply { Message = request.Name });
78+
}
79+
80+
// Arrange
81+
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));
82+
83+
var connectCount = 0;
84+
var channel = await BalancerHelpers.CreateChannel(LoggerFactory, new PickFirstConfig(), new[] { endpoint.Address }, connectTimeout: TimeSpan.FromMilliseconds(200), socketConnect:
85+
async (socket, endpoint, cancellationToken) =>
86+
{
87+
if (Interlocked.Increment(ref connectCount) == 1)
88+
{
89+
await Task.Delay(1000, cancellationToken);
90+
}
91+
await socket.ConnectAsync(endpoint, cancellationToken);
92+
}).DefaultTimeout();
93+
var client = TestClientFactory.Create(channel, endpoint.Method);
94+
95+
// Assert
96+
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync).DefaultTimeout();
97+
Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
98+
Assert.IsInstanceOf(typeof(TimeoutException), ex.InnerException);
99+
100+
await client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync.DefaultTimeout();
101+
}
102+
64103
[Test]
65104
public async Task UnaryCall_CallAfterCancellation_Success()
66105
{

0 commit comments

Comments
 (0)