Skip to content

Commit cde615e

Browse files
Fixes "InvalidOperationException" errors by performing async operations in SemaphoreSlim (#796)
1 parent f0572f3 commit cde615e

File tree

10 files changed

+201
-67
lines changed

10 files changed

+201
-67
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@
442442
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
443443
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
444444
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
445+
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
445446
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
446447
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
447448
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
9393
}
9494

9595
_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
96-
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
96+
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
9797

9898
_stream = _pipeStream;
9999
_status = TdsEnums.SNI_SUCCESS;
@@ -286,7 +286,7 @@ public override uint Send(SNIPacket packet)
286286
}
287287

288288
// this lock ensures that two packets are not being written to the transport at the same time
289-
// so that sending a standard and an out-of-band packet are both written atomically no data is
289+
// so that sending a standard and an out-of-band packet are both written atomically no data is
290290
// interleaved
291291
lock (_sendSync)
292292
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Net.Security;
6+
using System.IO;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
using System.Net.Sockets;
10+
11+
namespace Microsoft.Data.SqlClient.SNI
12+
{
13+
/// <summary>
14+
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
15+
/// </summary>
16+
internal class SNISslStream : SslStream
17+
{
18+
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
19+
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
20+
21+
public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
22+
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
23+
{
24+
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
25+
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
26+
}
27+
28+
// Prevent ReadAsync collisions by running the task in a Semaphore Slim
29+
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
30+
{
31+
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
32+
try
33+
{
34+
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
35+
}
36+
finally
37+
{
38+
_readAsyncSemaphore.Release();
39+
}
40+
}
41+
42+
// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
43+
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
44+
{
45+
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
46+
try
47+
{
48+
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
49+
}
50+
finally
51+
{
52+
_writeAsyncSemaphore.Release();
53+
}
54+
}
55+
}
56+
57+
/// <summary>
58+
/// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
59+
/// </summary>
60+
internal class SNINetworkStream : NetworkStream
61+
{
62+
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
63+
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
64+
65+
public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
66+
{
67+
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
68+
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
69+
}
70+
71+
// Prevent ReadAsync collisions by running the task in a Semaphore Slim
72+
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
73+
{
74+
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
75+
try
76+
{
77+
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
78+
}
79+
finally
80+
{
81+
_readAsyncSemaphore.Release();
82+
}
83+
}
84+
85+
// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
86+
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
87+
{
88+
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
89+
try
90+
{
91+
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
92+
}
93+
finally
94+
{
95+
_writeAsyncSemaphore.Release();
96+
}
97+
}
98+
}
99+
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs

+54-48
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
143143
bool reportError = true;
144144

145145
// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
146-
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
146+
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
147147
// IPv4 fails. The exceptions will be throw to upper level and be handled as before.
148148
try
149149
{
@@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
160160
{
161161
// Retry with cached IP address
162162
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
163-
{
163+
{
164164
if (hasCachedDNSInfo == false)
165165
{
166166
throw;
167167
}
168168
else
169169
{
170-
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
170+
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
171171

172172
try
173173
{
@@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
180180
_socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo);
181181
}
182182
}
183-
catch(Exception exRetry)
183+
catch (Exception exRetry)
184184
{
185-
if (exRetry is SocketException || exRetry is ArgumentNullException
185+
if (exRetry is SocketException || exRetry is ArgumentNullException
186186
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
187187
{
188188
if (parallel)
@@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
199199
throw;
200200
}
201201
}
202-
}
202+
}
203203
}
204204
else
205205
{
@@ -223,10 +223,10 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
223223
}
224224

225225
_socket.NoDelay = true;
226-
_tcpStream = new NetworkStream(_socket, true);
226+
_tcpStream = new SNINetworkStream(_socket, true);
227227

228228
_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
229-
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
229+
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
230230
}
231231
catch (SocketException se)
232232
{
@@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
331331
}
332332

333333
CancellationTokenSource cts = null;
334-
334+
335335
void Cancel()
336336
{
337337
for (int i = 0; i < sockets.Length; ++i)
@@ -355,7 +355,7 @@ void Cancel()
355355
}
356356

357357
Socket availableSocket = null;
358-
try
358+
try
359359
{
360360
for (int i = 0; i < sockets.Length; ++i)
361361
{
@@ -566,45 +566,45 @@ public override uint Send(SNIPacket packet)
566566
{
567567
bool releaseLock = false;
568568
try
569-
{
570-
// is the packet is marked out out-of-band (attention packets only) it must be
571-
// sent immediately even if a send of recieve operation is already in progress
572-
// because out of band packets are used to cancel ongoing operations
573-
// so try to take the lock if possible but continue even if it can't be taken
574-
if (packet.IsOutOfBand)
575-
{
576-
Monitor.TryEnter(this, ref releaseLock);
577-
}
578-
else
579-
{
580-
Monitor.Enter(this);
581-
releaseLock = true;
582-
}
583-
584-
// this lock ensures that two packets are not being written to the transport at the same time
585-
// so that sending a standard and an out-of-band packet are both written atomically no data is
586-
// interleaved
587-
lock (_sendSync)
588569
{
589-
try
590-
{
591-
packet.WriteToStream(_stream);
592-
return TdsEnums.SNI_SUCCESS;
593-
}
594-
catch (ObjectDisposedException ode)
570+
// is the packet is marked out out-of-band (attention packets only) it must be
571+
// sent immediately even if a send of recieve operation is already in progress
572+
// because out of band packets are used to cancel ongoing operations
573+
// so try to take the lock if possible but continue even if it can't be taken
574+
if (packet.IsOutOfBand)
595575
{
596-
return ReportTcpSNIError(ode);
576+
Monitor.TryEnter(this, ref releaseLock);
597577
}
598-
catch (SocketException se)
578+
else
599579
{
600-
return ReportTcpSNIError(se);
580+
Monitor.Enter(this);
581+
releaseLock = true;
601582
}
602-
catch (IOException ioe)
583+
584+
// this lock ensures that two packets are not being written to the transport at the same time
585+
// so that sending a standard and an out-of-band packet are both written atomically no data is
586+
// interleaved
587+
lock (_sendSync)
603588
{
604-
return ReportTcpSNIError(ioe);
589+
try
590+
{
591+
packet.WriteToStream(_stream);
592+
return TdsEnums.SNI_SUCCESS;
593+
}
594+
catch (ObjectDisposedException ode)
595+
{
596+
return ReportTcpSNIError(ode);
597+
}
598+
catch (SocketException se)
599+
{
600+
return ReportTcpSNIError(se);
601+
}
602+
catch (IOException ioe)
603+
{
604+
return ReportTcpSNIError(ioe);
605+
}
605606
}
606607
}
607-
}
608608
finally
609609
{
610610
if (releaseLock)
@@ -633,7 +633,8 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds)
633633
_socket.ReceiveTimeout = timeoutInMilliseconds;
634634
}
635635
else if (timeoutInMilliseconds == -1)
636-
{ // SqlCient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
636+
{
637+
// SqlClient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
637638
_socket.ReceiveTimeout = 0;
638639
}
639640
else
@@ -706,12 +707,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
706707
/// <returns>SNI error code</returns>
707708
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
708709
{
709-
SNIAsyncCallback cb = callback ?? _sendCallback;
710-
lock (this)
710+
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNIMarsHandle.SendAsync |SNI|INFO|SCOPE>");
711+
try
711712
{
713+
SNIAsyncCallback cb = callback ?? _sendCallback;
712714
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
715+
return TdsEnums.SNI_SUCCESS_IO_PENDING;
716+
}
717+
finally
718+
{
719+
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
713720
}
714-
return TdsEnums.SNI_SUCCESS_IO_PENDING;
715721
}
716722

717723
/// <summary>
@@ -745,15 +751,15 @@ public override uint CheckConnection()
745751
{
746752
try
747753
{
748-
// _socket.Poll method with argument SelectMode.SelectRead returns
754+
// _socket.Poll method with argument SelectMode.SelectRead returns
749755
// True : if Listen has been called and a connection is pending, or
750756
// True : if data is available for reading, or
751757
// True : if the connection has been closed, reset, or terminated, i.e no active connection.
752758
// False : otherwise.
753759
// _socket.Available property returns the number of bytes of data available to read.
754760
//
755-
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
756-
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
761+
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
762+
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
757763
// return true we can safely determine that the connection is no longer active.
758764
if (!_socket.Connected || (_socket.Poll(100, SelectMode.SelectRead) && _socket.Available == 0))
759765
{

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs

+4-13
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,16 @@ namespace Microsoft.Data.SqlClient.SNI
1212
internal sealed partial class SslOverTdsStream
1313
{
1414
public override int Read(byte[] buffer, int offset, int count)
15-
{
16-
return Read(buffer.AsSpan(offset, count));
17-
}
15+
=> Read(buffer.AsSpan(offset, count));
1816

1917
public override void Write(byte[] buffer, int offset, int count)
20-
{
21-
Write(buffer.AsSpan(offset, count));
22-
}
18+
=> Write(buffer.AsSpan(offset, count));
2319

2420
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
25-
{
26-
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
27-
}
21+
=> ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
2822

2923
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
30-
{
31-
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
32-
}
24+
=> WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
3325

3426
public override int Read(Span<byte> buffer)
3527
{
@@ -288,7 +280,6 @@ public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancella
288280

289281
await _stream.FlushAsync().ConfigureAwait(false);
290282

291-
292283
remaining = remaining.Slice(dataLength);
293284
}
294285
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1319,7 +1319,7 @@ private void ThrowIfReconnectionHasBeenCanceled()
13191319
if (_stateObj == null)
13201320
{
13211321
var reconnectionCompletionSource = _reconnectionCompletionSource;
1322-
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task.IsCanceled)
1322+
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task != null && reconnectionCompletionSource.Task.IsCanceled)
13231323
{
13241324
throw SQL.CR_ReconnectionCancelled();
13251325
}

0 commit comments

Comments
 (0)