Skip to content

Commit 10421a9

Browse files
Adding queue back to maintain FIFO for extra security.
1 parent f76e1c0 commit 10421a9

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs

+13-13
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ namespace Microsoft.Data.SqlClient.SNI
1515
/// </summary>
1616
internal class SNISslStream : SslStream
1717
{
18-
private readonly SemaphoreSlim _writeAsyncSemaphore;
19-
private readonly SemaphoreSlim _readAsyncSemaphore;
18+
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
19+
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
2020

2121
public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
2222
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
2323
{
24-
_writeAsyncSemaphore = new SemaphoreSlim(1);
25-
_readAsyncSemaphore = new SemaphoreSlim(1);
24+
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
25+
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
2626
}
2727

2828
// Prevent ReadAsync collisions by running the task in a Semaphore Slim
@@ -31,7 +31,7 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
3131
await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false);
3232
try
3333
{
34-
return await base.ReadAsync(buffer, offset, count, cancellationToken);
34+
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
3535
}
3636
finally
3737
{
@@ -45,7 +45,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
4545
await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false);
4646
try
4747
{
48-
await base.WriteAsync(buffer, offset, count, cancellationToken);
48+
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
4949
}
5050
finally
5151
{
@@ -59,22 +59,22 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
5959
/// </summary>
6060
internal class SNINetworkStream : NetworkStream
6161
{
62-
private readonly SemaphoreSlim _writeAsyncSemaphore;
63-
private readonly SemaphoreSlim _readAsyncSemaphore;
62+
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
63+
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
6464

6565
public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
6666
{
67-
_writeAsyncSemaphore = new SemaphoreSlim(1);
68-
_readAsyncSemaphore = new SemaphoreSlim(1);
67+
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
68+
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
6969
}
7070

71-
// Prevent the ReadAsync collisions by running the task in a Semaphore Slim
71+
// Prevent ReadAsync collisions by running the task in a Semaphore Slim
7272
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
7373
{
7474
await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false);
7575
try
7676
{
77-
return await base.ReadAsync(buffer, offset, count, cancellationToken);
77+
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
7878
}
7979
finally
8080
{
@@ -88,7 +88,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
8888
await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false);
8989
try
9090
{
91-
await base.WriteAsync(buffer, offset, count, cancellationToken);
91+
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
9292
}
9393
finally
9494
{

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs

+43
Original file line numberDiff line numberDiff line change
@@ -2133,4 +2133,47 @@ public static MethodInfo GetPromotedToken
21332133
}
21342134
}
21352135

2136+
/// <summary>
2137+
/// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks.
2138+
/// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation.
2139+
/// </summary>
2140+
internal class ConcurrentQueueSemaphore
2141+
{
2142+
private readonly SemaphoreSlim _semaphore;
2143+
private readonly ConcurrentQueue<TaskCompletionSource<bool>> _queue =
2144+
new ConcurrentQueue<TaskCompletionSource<bool>>();
2145+
2146+
public ConcurrentQueueSemaphore(int initialCount)
2147+
{
2148+
_semaphore = new SemaphoreSlim(initialCount);
2149+
}
2150+
2151+
public ConcurrentQueueSemaphore(int initialCount, int maxCount)
2152+
{
2153+
_semaphore = new SemaphoreSlim(initialCount, maxCount);
2154+
}
2155+
2156+
public void Wait()
2157+
{
2158+
WaitAsync().Wait();
2159+
}
2160+
2161+
public Task WaitAsync()
2162+
{
2163+
var tcs = new TaskCompletionSource<bool>();
2164+
_queue.Enqueue(tcs);
2165+
_semaphore.WaitAsync().ContinueWith(t =>
2166+
{
2167+
if (_queue.TryDequeue(out TaskCompletionSource<bool> popped))
2168+
popped.SetResult(true);
2169+
});
2170+
return tcs.Task;
2171+
}
2172+
2173+
public void Release()
2174+
{
2175+
_semaphore.Release();
2176+
}
2177+
}
2178+
21362179
}//namespace

0 commit comments

Comments
 (0)