Skip to content

Commit 1d12340

Browse files
authored
Support idle connection timeout with pending sockets (#2213)
1 parent 0ab3ada commit 1d12340

File tree

6 files changed

+166
-10
lines changed

6 files changed

+166
-10
lines changed

src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs

+31-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
5656
private readonly ILogger _logger;
5757
private readonly Subchannel _subchannel;
5858
private readonly TimeSpan _socketPingInterval;
59+
private readonly TimeSpan _socketIdleTimeout;
5960
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask> _socketConnect;
6061
private readonly List<ActiveStream> _activeStreams;
6162
private readonly Timer _socketConnectedTimer;
@@ -64,20 +65,23 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
6465
internal Socket? _initialSocket;
6566
private BalancerAddress? _initialSocketAddress;
6667
private List<ReadOnlyMemory<byte>>? _initialSocketData;
68+
private DateTime? _initialSocketCreatedTime;
6769
private bool _disposed;
6870
private BalancerAddress? _currentAddress;
6971

7072
public SocketConnectivitySubchannelTransport(
7173
Subchannel subchannel,
7274
TimeSpan socketPingInterval,
7375
TimeSpan? connectTimeout,
76+
TimeSpan socketIdleTimeout,
7477
ILoggerFactory loggerFactory,
7578
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
7679
{
7780
_logger = loggerFactory.CreateLogger<SocketConnectivitySubchannelTransport>();
7881
_subchannel = subchannel;
7982
_socketPingInterval = socketPingInterval;
8083
ConnectTimeout = connectTimeout;
84+
_socketIdleTimeout = socketIdleTimeout;
8185
_socketConnect = socketConnect ?? OnConnect;
8286
_activeStreams = new List<ActiveStream>();
8387
_socketConnectedTimer = NonCapturingTimer.Create(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
@@ -125,6 +129,7 @@ private void DisconnectUnsynchronized()
125129
_initialSocket = null;
126130
_initialSocketAddress = null;
127131
_initialSocketData = null;
132+
_initialSocketCreatedTime = null;
128133
_lastEndPointIndex = 0;
129134
_currentAddress = null;
130135
}
@@ -162,6 +167,7 @@ public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
162167
_initialSocket = socket;
163168
_initialSocketAddress = currentAddress;
164169
_initialSocketData = null;
170+
_initialSocketCreatedTime = DateTime.UtcNow;
165171

166172
// Schedule ping. Don't set a periodic interval to avoid any chance of timer causing the target method to run multiple times in paralle.
167173
// This could happen because of execution delays (e.g. hitting a debugger breakpoint).
@@ -338,6 +344,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
338344
Socket? socket = null;
339345
BalancerAddress? socketAddress = null;
340346
List<ReadOnlyMemory<byte>>? socketData = null;
347+
DateTime? socketCreatedTime = null;
341348
lock (Lock)
342349
{
343350
if (_initialSocket != null)
@@ -347,9 +354,11 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
347354
socket = _initialSocket;
348355
socketAddress = _initialSocketAddress;
349356
socketData = _initialSocketData;
357+
socketCreatedTime = _initialSocketCreatedTime;
350358
_initialSocket = null;
351359
_initialSocketAddress = null;
352360
_initialSocketData = null;
361+
_initialSocketCreatedTime = null;
353362

354363
// Double check the address matches the socket address and only use socket on match.
355364
// Not sure if this is possible in practice, but better safe than sorry.
@@ -365,10 +374,23 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
365374

366375
if (socket != null)
367376
{
368-
if (IsSocketInBadState(socket, address))
377+
Debug.Assert(socketCreatedTime != null);
378+
379+
var closeSocket = false;
380+
381+
if (DateTime.UtcNow > socketCreatedTime.Value.Add(_socketIdleTimeout))
382+
{
383+
SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _socketIdleTimeout);
384+
closeSocket = true;
385+
}
386+
else if (IsSocketInBadState(socket, address))
369387
{
370388
SocketConnectivitySubchannelTransportLog.ClosingUnusableSocketOnCreateStream(_logger, _subchannel.Id, address);
389+
closeSocket = true;
390+
}
371391

392+
if (closeSocket)
393+
{
372394
socket.Dispose();
373395
socket = null;
374396
socketData = null;
@@ -530,6 +552,9 @@ internal static class SocketConnectivitySubchannelTransportLog
530552
private static readonly Action<ILogger, int, BalancerAddress, Exception?> _closingUnusableSocketOnCreateStream =
531553
LoggerMessage.Define<int, BalancerAddress>(LogLevel.Debug, new EventId(16, "ClosingUnusableSocketOnCreateStream"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it can't be used. The socket either can't receive data or it has received unexpected data.");
532554

555+
private static readonly Action<ILogger, int, BalancerAddress, TimeSpan, Exception?> _closingSocketFromIdleTimeoutOnCreateStream =
556+
LoggerMessage.Define<int, BalancerAddress, TimeSpan>(LogLevel.Debug, new EventId(16, "ClosingSocketFromIdleTimeoutOnCreateStream"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it exceeds the idle timeout of {SocketIdleTimeout}.");
557+
533558
public static void ConnectingSocket(ILogger logger, int subchannelId, BalancerAddress address)
534559
{
535560
_connectingSocket(logger, subchannelId, address, null);
@@ -609,5 +634,10 @@ public static void ClosingUnusableSocketOnCreateStream(ILogger logger, int subch
609634
{
610635
_closingUnusableSocketOnCreateStream(logger, subchannelId, address, null);
611636
}
637+
638+
public static void ClosingSocketFromIdleTimeoutOnCreateStream(ILogger logger, int subchannelId, BalancerAddress address, TimeSpan socketIdleTimeout)
639+
{
640+
_closingSocketFromIdleTimeoutOnCreateStream(logger, subchannelId, address, socketIdleTimeout, null);
641+
}
612642
}
613643
#endif

src/Grpc.Net.Client/GrpcChannel.cs

+10-4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
6161
internal Uri Address { get; }
6262
internal HttpMessageInvoker HttpInvoker { get; }
6363
internal TimeSpan? ConnectTimeout { get; }
64+
internal TimeSpan? ConnectionIdleTimeout { get; }
6465
internal HttpHandlerType HttpHandlerType { get; }
6566
internal TimeSpan InitialReconnectBackoff { get; }
6667
internal TimeSpan? MaxReconnectBackoff { get; }
@@ -125,7 +126,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
125126

126127
var resolverFactory = GetResolverFactory(channelOptions);
127128
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
128-
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
129+
(HttpHandlerType, ConnectTimeout, ConnectionIdleTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
129130

130131
SubchannelTransportFactory = channelOptions.ResolveService<ISubchannelTransportFactory>(new SubChannelTransportFactory(this));
131132

@@ -154,7 +155,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
154155
throw new ArgumentException($"Address '{address.OriginalString}' doesn't have a host. Address should include a scheme, host, and optional port. For example, 'https://localhost:5001'.");
155156
}
156157
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
157-
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
158+
(HttpHandlerType, ConnectTimeout, ConnectionIdleTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
158159
#endif
159160

160161
HttpInvoker = channelOptions.HttpClient ?? CreateInternalHttpInvoker(channelOptions.HttpHandler);
@@ -243,12 +244,14 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
243244
{
244245
HttpHandlerType type;
245246
TimeSpan? connectTimeout;
247+
TimeSpan? connectionIdleTimeout;
246248

247249
#if NET5_0_OR_GREATER
248250
var socketsHttpHandler = HttpRequestHelpers.GetHttpHandlerType<SocketsHttpHandler>(channelOptions.HttpHandler)!;
249251

250252
type = HttpHandlerType.SocketsHttpHandler;
251253
connectTimeout = socketsHttpHandler.ConnectTimeout;
254+
connectionIdleTimeout = socketsHttpHandler.PooledConnectionIdleTimeout;
252255

253256
// Check if the SocketsHttpHandler is being shared by channels.
254257
// It has already been setup by another channel (i.e. ConnectCallback is set) then
@@ -261,6 +264,7 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
261264
{
262265
type = HttpHandlerType.Custom;
263266
connectTimeout = null;
267+
connectionIdleTimeout = null;
264268
}
265269
}
266270

@@ -282,8 +286,9 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
282286
#else
283287
type = HttpHandlerType.SocketsHttpHandler;
284288
connectTimeout = null;
289+
connectionIdleTimeout = null;
285290
#endif
286-
return new HttpHandlerContext(type, connectTimeout);
291+
return new HttpHandlerContext(type, connectTimeout, connectionIdleTimeout);
287292
}
288293
if (HttpRequestHelpers.GetHttpHandlerType<HttpClientHandler>(channelOptions.HttpHandler) != null)
289294
{
@@ -837,6 +842,7 @@ public ISubchannelTransport Create(Subchannel subchannel)
837842
subchannel,
838843
SocketConnectivitySubchannelTransport.SocketPingInterval,
839844
_channel.ConnectTimeout,
845+
_channel.ConnectionIdleTimeout ?? TimeSpan.FromMinutes(1),
840846
_channel.LoggerFactory,
841847
socketConnect: null);
842848
}
@@ -895,7 +901,7 @@ public static void AddressPathUnused(ILogger logger, string address)
895901
}
896902
}
897903

898-
private readonly record struct HttpHandlerContext(HttpHandlerType HttpHandlerType, TimeSpan? ConnectTimeout = null);
904+
private readonly record struct HttpHandlerContext(HttpHandlerType HttpHandlerType, TimeSpan? ConnectTimeout = null, TimeSpan? ConnectionIdleTimeout = null);
899905
}
900906

901907
internal enum HttpHandlerType

test/FunctionalTests/Balancer/BalancerHelpers.cs

+10-5
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,14 @@ public static Task<GrpcChannel> CreateChannel(
135135
bool? connect = null,
136136
RetryPolicy? retryPolicy = null,
137137
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
138-
TimeSpan? connectTimeout = null)
138+
TimeSpan? connectTimeout = null,
139+
TimeSpan? connectionIdleTimeout = null)
139140
{
140141
var resolver = new TestResolver();
141142
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
142143
resolver.UpdateAddresses(e);
143144

144-
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout);
145+
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout);
145146
}
146147

147148
public static async Task<GrpcChannel> CreateChannel(
@@ -152,12 +153,13 @@ public static async Task<GrpcChannel> CreateChannel(
152153
bool? connect = null,
153154
RetryPolicy? retryPolicy = null,
154155
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
155-
TimeSpan? connectTimeout = null)
156+
TimeSpan? connectTimeout = null,
157+
TimeSpan? connectionIdleTimeout = null)
156158
{
157159
var services = new ServiceCollection();
158160
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
159161
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
160-
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, socketConnect));
162+
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
161163
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());
162164

163165
var serviceConfig = new ServiceConfig();
@@ -214,12 +216,14 @@ internal class TestSubchannelTransportFactory : ISubchannelTransportFactory
214216
{
215217
private readonly TimeSpan _socketPingInterval;
216218
private readonly TimeSpan? _connectTimeout;
219+
private readonly TimeSpan _connectionIdleTimeout;
217220
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? _socketConnect;
218221

219-
public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout, Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
222+
public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout, TimeSpan connectionIdleTimeout, Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
220223
{
221224
_socketPingInterval = socketPingInterval;
222225
_connectTimeout = connectTimeout;
226+
_connectionIdleTimeout = connectionIdleTimeout;
223227
_socketConnect = socketConnect;
224228
}
225229

@@ -230,6 +234,7 @@ public ISubchannelTransport Create(Subchannel subchannel)
230234
subchannel,
231235
_socketPingInterval,
232236
_connectTimeout,
237+
_connectionIdleTimeout,
233238
subchannel._manager.LoggerFactory,
234239
_socketConnect);
235240
#else

test/FunctionalTests/Balancer/ConnectionTests.cs

+39
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,45 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
141141
await ExceptionAssert.ThrowsAsync<OperationCanceledException>(() => connectTask).DefaultTimeout();
142142
}
143143

144+
[Test]
145+
public async Task Active_UnaryCall_ConnectionIdleTimeout_SocketRecreated()
146+
{
147+
// Ignore errors
148+
SetExpectedErrorsFilter(writeContext =>
149+
{
150+
return true;
151+
});
152+
153+
Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
154+
{
155+
return Task.FromResult(new HelloReply { Message = request.Name });
156+
}
157+
158+
// Arrange
159+
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));
160+
161+
var connectionIdleTimeout = TimeSpan.FromSeconds(1);
162+
var channel = await BalancerHelpers.CreateChannel(
163+
LoggerFactory,
164+
new PickFirstConfig(),
165+
new[] { endpoint.Address },
166+
connectionIdleTimeout: connectionIdleTimeout).DefaultTimeout();
167+
168+
Logger.LogInformation("Connecting channel.");
169+
await channel.ConnectAsync();
170+
171+
await Task.Delay(connectionIdleTimeout);
172+
173+
var client = TestClientFactory.Create(channel, endpoint.Method);
174+
var response = await client.UnaryCall(new HelloRequest { Name = "Test!" }).ResponseAsync.DefaultTimeout();
175+
176+
// Assert
177+
Assert.AreEqual("Test!", response.Message);
178+
179+
AssertHasLog(LogLevel.Debug, "ClosingSocketFromIdleTimeoutOnCreateStream", "Subchannel id '1' socket 127.0.0.1:50051 is being closed because it exceeds the idle timeout of 00:00:01.");
180+
AssertHasLog(LogLevel.Trace, "ConnectingOnCreateStream", "Subchannel id '1' doesn't have a connected socket available. Connecting new stream socket for 127.0.0.1:50051.");
181+
}
182+
144183
[Test]
145184
public async Task Active_UnaryCall_MultipleStreams_UnavailableAddress_FallbackToWorkingAddress()
146185
{

test/Grpc.Net.Client.Tests/Balancer/StreamWrapperTests.cs

+48
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,54 @@ namespace Grpc.Net.Client.Tests.Balancer;
2525
[TestFixture]
2626
public class StreamWrapperTests
2727
{
28+
[Test]
29+
public async Task ReadAsync_ExactSize_Read()
30+
{
31+
// Arrange
32+
var ms = new MemoryStream(new byte[] { 4 });
33+
var data = new List<ReadOnlyMemory<byte>>
34+
{
35+
new byte[] { 1, 2, 3 }
36+
};
37+
var streamWrapper = new StreamWrapper(ms, s => { }, data);
38+
var buffer = new byte[3];
39+
40+
// Act & Assert
41+
Assert.AreEqual(3, await streamWrapper.ReadAsync(buffer));
42+
Assert.AreEqual(1, buffer[0]);
43+
Assert.AreEqual(2, buffer[1]);
44+
Assert.AreEqual(3, buffer[2]);
45+
46+
Assert.AreEqual(1, await streamWrapper.ReadAsync(buffer));
47+
Assert.AreEqual(4, buffer[0]);
48+
49+
Assert.AreEqual(0, await streamWrapper.ReadAsync(buffer));
50+
}
51+
52+
[Test]
53+
public async Task ReadAsync_BiggerThanNeeded_Read()
54+
{
55+
// Arrange
56+
var ms = new MemoryStream(new byte[] { 4 });
57+
var data = new List<ReadOnlyMemory<byte>>
58+
{
59+
new byte[] { 1, 2, 3 }
60+
};
61+
var streamWrapper = new StreamWrapper(ms, s => { }, data);
62+
var buffer = new byte[4];
63+
64+
// Act & Assert
65+
Assert.AreEqual(3, await streamWrapper.ReadAsync(buffer));
66+
Assert.AreEqual(1, buffer[0]);
67+
Assert.AreEqual(2, buffer[1]);
68+
Assert.AreEqual(3, buffer[2]);
69+
70+
Assert.AreEqual(1, await streamWrapper.ReadAsync(buffer));
71+
Assert.AreEqual(4, buffer[0]);
72+
73+
Assert.AreEqual(0, await streamWrapper.ReadAsync(buffer));
74+
}
75+
2876
[Test]
2977
public async Task ReadAsync_MultipleInitialData_ReadInOrder()
3078
{

test/Grpc.Net.Client.Tests/GrpcChannelTests.cs

+28
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,34 @@ public void Build_InsecureCredentialsWithHttps_ThrowsError()
208208
Assert.AreEqual("Channel is configured with insecure channel credentials and can't use a HttpClient with a 'https' scheme.", ex.Message);
209209
}
210210
211+
#if SUPPORT_LOAD_BALANCING
212+
[Test]
213+
public void Build_ConnectTimeout_ReadFromSocketsHttpHandler()
214+
{
215+
// Arrange & Act
216+
var channel = GrpcChannel.ForAddress("https://localhost", CreateGrpcChannelOptions(o => o.HttpHandler = new SocketsHttpHandler
217+
{
218+
ConnectTimeout = TimeSpan.FromSeconds(1)
219+
}));
220+
221+
// Assert
222+
Assert.AreEqual(TimeSpan.FromSeconds(1), channel.ConnectTimeout);
223+
}
224+
225+
[Test]
226+
public void Build_ConnectionIdleTimeout_ReadFromSocketsHttpHandler()
227+
{
228+
// Arrange & Act
229+
var channel = GrpcChannel.ForAddress("https://localhost", CreateGrpcChannelOptions(o => o.HttpHandler = new SocketsHttpHandler
230+
{
231+
PooledConnectionIdleTimeout = TimeSpan.FromSeconds(1)
232+
}));
233+
234+
// Assert
235+
Assert.AreEqual(TimeSpan.FromSeconds(1), channel.ConnectionIdleTimeout);
236+
}
237+
#endif
238+
211239
[Test]
212240
public void Build_HttpClientAndHttpHandler_ThrowsError()
213241
{

0 commit comments

Comments
 (0)