Skip to content

Commit a2d005c

Browse files
authored
Correctly check socket on stream creation (#2215)
1 parent 1d12340 commit a2d005c

File tree

3 files changed

+139
-57
lines changed

3 files changed

+139
-57
lines changed

src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs

+71-49
Original file line number | Diff line number | Diff line change
@@ -253,53 +253,7 @@ private void OnCheckSocketConnection(object? state)
253253
{
254254
CompatibilityHelpers.Assert(socketAddress != null);
255255

256-
try
257-
{
258-
SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);
259-
260-
// Poll socket to check if it can be read from. Unfortunatly this requires reading pending data.
261-
// The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it.
262-
//
263-
// Available data needs to be read now because the only way to determine whether the connection is closed is to
264-
// get the results of polling after available data is received.
265-
bool hasReadData;
266-
do
267-
{
268-
closeSocket = IsSocketInBadState(socket, socketAddress);
269-
var available = socket.Available;
270-
if (available > 0)
271-
{
272-
hasReadData = true;
273-
var serverDataAvailable = CalculateInitialSocketDataLength(_initialSocketData) + available;
274-
if (serverDataAvailable > MaximumInitialSocketDataSize)
275-
{
276-
// Data sent to the client before a connection is started shouldn't be large.
277-
// Put a maximum limit on the buffer size to prevent an unexpected scenario from consuming too much memory.
278-
throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded.");
279-
}
280-
281-
SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available);
282-
283-
// Data is already available so this won't block.
284-
var buffer = new byte[available];
285-
var readCount = socket.Receive(buffer);
286-
287-
_initialSocketData ??= new List<ReadOnlyMemory<byte>>();
288-
_initialSocketData.Add(buffer.AsMemory(0, readCount));
289-
}
290-
else
291-
{
292-
hasReadData = false;
293-
}
294-
}
295-
while (hasReadData);
296-
}
297-
catch (Exception ex)
298-
{
299-
closeSocket = true;
300-
checkException = ex;
301-
SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
302-
}
256+
closeSocket = ShouldCloseSocket(socket, socketAddress, ref _initialSocketData, out checkException);
303257
}
304258
}
305259

@@ -383,7 +337,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
383337
SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _socketIdleTimeout);
384338
closeSocket = true;
385339
}
386-
else if (IsSocketInBadState(socket, address))
340+
else if (ShouldCloseSocket(socket, address, ref socketData, out _))
387341
{
388342
SocketConnectivitySubchannelTransportLog.ClosingUnusableSocketOnCreateStream(_logger, _subchannel.Id, address);
389343
closeSocket = true;
@@ -419,7 +373,75 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
419373
return stream;
420374
}
421375

422-
private bool IsSocketInBadState(Socket socket, BalancerAddress address)
376+
/// <summary>
377+
/// Checks whether the socket is healthy. May read available data into the passed in buffer.
378+
/// Returns true if the socket should be closed.
379+
/// </summary>
380+
private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref List<ReadOnlyMemory<byte>>? socketData, out Exception? checkException)
381+
{
382+
checkException = null;
383+
384+
try
385+
{
386+
SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);
387+
388+
// Poll socket to check if it can be read from. Unfortunately this requires reading pending data.
389+
// The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it.
390+
//
391+
// Available data needs to be read now because the only way to determine whether the connection is
392+
// closed is to get the results of polling after available data is received.
393+
// For example, the server may have sent an HTTP/2 SETTINGS or GOAWAY frame.
394+
// We need to cache whatever we read so it isn't dropped.
395+
do
396+
{
397+
if (PollSocket(socket, socketAddress))
398+
{
399+
// Polling socket reported an unhealthy state.
400+
return true;
401+
}
402+
403+
var available = socket.Available;
404+
if (available > 0)
405+
{
406+
var serverDataAvailable = CalculateInitialSocketDataLength(socketData) + available;
407+
if (serverDataAvailable > MaximumInitialSocketDataSize)
408+
{
409+
// Data sent to the client before a connection is started shouldn't be large.
410+
// Put a maximum limit on the buffer size to prevent an unexpected scenario from consuming too much memory.
411+
throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded.");
412+
}
413+
414+
SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available);
415+
416+
// Data is already available so this won't block.
417+
var buffer = new byte[available];
418+
var readCount = socket.Receive(buffer);
419+
420+
socketData ??= new List<ReadOnlyMemory<byte>>();
421+
socketData.Add(buffer.AsMemory(0, readCount));
422+
}
423+
else
424+
{
425+
// There is no more available data to read and the socket is healthy.
426+
return false;
427+
}
428+
}
429+
while (true);
430+
}
431+
catch (Exception ex)
432+
{
433+
checkException = ex;
434+
SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
435+
return true;
436+
}
437+
}
438+
439+
/// <summary>
440+
/// Poll the socket to check for health and available data.
441+
/// Shouldn't be used by itself as data needs to be consumed to accurately report the socket health.
442+
/// <see cref="ShouldCloseSocket"/> handles consuming data and getting the socket health.
443+
/// </summary>
444+
private bool PollSocket(Socket socket, BalancerAddress address)
423445
{
424446
// From https://github.com/dotnet/runtime/blob/3195fbbd82fdb7f132d6698591ba6489ad6dd8cf/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs#L158-L168
425447
try

test/FunctionalTests/Balancer/BalancerHelpers.cs

+18-7
Original file line number | Diff line number | Diff line change
@@ -54,11 +54,12 @@ public static EndpointContext<TRequest, TResponse> CreateGrpcEndpoint<TRequest,
5454
HttpProtocols? protocols = null,
5555
bool? isHttps = null,
5656
X509Certificate2? certificate = null,
57-
ILoggerFactory? loggerFactory = null)
57+
ILoggerFactory? loggerFactory = null,
58+
Action<KestrelServerOptions>? configureServer = null)
5859
where TRequest : class, IMessage, new()
5960
where TResponse : class, IMessage, new()
6061
{
61-
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory);
62+
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory, configureServer);
6263
var method = server.DynamicGrpc.AddUnaryMethod(callHandler, methodName);
6364
var url = server.GetUrl(isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2);
6465

@@ -88,7 +89,13 @@ public void Dispose()
8889
}
8990
}
9091

91-
public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? protocols = null, bool? isHttps = null, X509Certificate2? certificate = null, ILoggerFactory? loggerFactory = null)
92+
public static GrpcTestFixture<Startup> CreateServer(
93+
int port,
94+
HttpProtocols? protocols = null,
95+
bool? isHttps = null,
96+
X509Certificate2? certificate = null,
97+
ILoggerFactory? loggerFactory = null,
98+
Action<KestrelServerOptions>? configureServer = null)
9299
{
93100
var endpointName = isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2;
94101

@@ -102,6 +109,8 @@ public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? pro
102109
},
103110
(options, urls) =>
104111
{
112+
configureServer?.Invoke(options);
113+
105114
urls[endpointName] = isHttps.GetValueOrDefault(false)
106115
? $"https://127.0.0.1:{port}"
107116
: $"http://127.0.0.1:{port}";
@@ -136,13 +145,14 @@ public static Task<GrpcChannel> CreateChannel(
136145
RetryPolicy? retryPolicy = null,
137146
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
138147
TimeSpan? connectTimeout = null,
139-
TimeSpan? connectionIdleTimeout = null)
148+
TimeSpan? connectionIdleTimeout = null,
149+
TimeSpan? socketPingInterval = null)
140150
{
141151
var resolver = new TestResolver();
142152
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
143153
resolver.UpdateAddresses(e);
144154

145-
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout);
155+
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout, socketPingInterval);
146156
}
147157

148158
public static async Task<GrpcChannel> CreateChannel(
@@ -154,12 +164,13 @@ public static async Task<GrpcChannel> CreateChannel(
154164
RetryPolicy? retryPolicy = null,
155165
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
156166
TimeSpan? connectTimeout = null,
157-
TimeSpan? connectionIdleTimeout = null)
167+
TimeSpan? connectionIdleTimeout = null,
168+
TimeSpan? socketPingInterval = null)
158169
{
159170
var services = new ServiceCollection();
160171
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
161172
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
162-
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
173+
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(socketPingInterval ?? TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
163174
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());
164175

165176
var serviceConfig = new ServiceConfig();

test/FunctionalTests/Balancer/ConnectionTests.cs

+50-1
Original file line number | Diff line number | Diff line change
@@ -156,7 +156,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
156156
}
157157

158158
// Arrange
159-
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));
159+
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod), loggerFactory: LoggerFactory);
160160

161161
var connectionIdleTimeout = TimeSpan.FromSeconds(1);
162162
var channel = await BalancerHelpers.CreateChannel(
@@ -180,6 +180,55 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
180180
AssertHasLog(LogLevel.Trace, "ConnectingOnCreateStream", "Subchannel id '1' doesn't have a connected socket available. Connecting new stream socket for 127.0.0.1:50051.");
181181
}
182182

183+
[Test]
184+
public async Task Active_UnaryCall_ServerCloseOnKeepAlive_SocketRecreatedOnRequest()
185+
{
186+
// Ignore errors
187+
SetExpectedErrorsFilter(writeContext =>
188+
{
189+
return true;
190+
});
191+
192+
Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
193+
{
194+
return Task.FromResult(new HelloReply { Message = request.Name });
195+
}
196+
197+
// In this test the client connects to the server, and the server then closes it after keep-alive is triggered.
198+
// The client then starts a gRPC call to the server. The client should discard the closed socket and create a new one.
199+
200+
// Arrange
201+
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(
202+
50051,
203+
UnaryMethod,
204+
nameof(UnaryMethod),
205+
loggerFactory: LoggerFactory,
206+
configureServer: o => o.Limits.KeepAliveTimeout = TimeSpan.FromSeconds(1));
207+
208+
// Don't timeout the socket or ping it from the client.
209+
var channel = await BalancerHelpers.CreateChannel(
210+
LoggerFactory,
211+
new RoundRobinConfig(),
212+
new[] { endpoint.Address },
213+
connectionIdleTimeout: TimeSpan.FromMinutes(30),
214+
socketPingInterval: TimeSpan.FromMinutes(30)).DefaultTimeout();
215+
216+
Logger.LogInformation("Connecting channel.");
217+
await channel.ConnectAsync();
218+
219+
// Fails when this test is run with debugging. Kestrel doesn't trigger keepalive timeout if debugging is enabled.
220+
await TestHelpers.AssertIsTrueRetryAsync(() =>
221+
{
222+
return Logs.Any(l => l.LoggerName.StartsWith("Microsoft.AspNetCore.Server.Kestrel") && l.EventId.Name == "ConnectionStop");
223+
}, "Wait for server to close connection.");
224+
225+
var client = TestClientFactory.Create(channel, endpoint.Method);
226+
var response = await client.UnaryCall(new HelloRequest { Name = "Test!" }).ResponseAsync.DefaultTimeout();
227+
228+
// Assert
229+
Assert.AreEqual("Test!", response.Message);
230+
}
231+
183232
[Test]
184233
public async Task Active_UnaryCall_MultipleStreams_UnavailableAddress_FallbackToWorkingAddress()
185234
{

0 commit comments

Comments
 (0)