Skip to content

Commit fc6aece

Browse files
authored
Fix SqlSequentialStream multipacket reads stalling and add covering test (#603)
1 parent e1d16ad commit fc6aece

File tree

2 files changed

+106
-6
lines changed

2 files changed

+106
-6
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs

+14-5
Original file line numberDiff line numberDiff line change
@@ -4517,22 +4517,31 @@ private Task<int> GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, b
45174517
SetTimeout(_defaultTimeoutMilliseconds);
45184518

45194519
// Try to read without any continuations (all the data may already be in the stateObj's buffer)
4520-
if (!TryGetBytesInternalSequential(context.columnIndex, context.buffer, context.index, context.length, out bytesRead))
4520+
bool filledBuffer = context._reader.TryGetBytesInternalSequential(
4521+
context.columnIndex,
4522+
context.buffer,
4523+
context.index + context.totalBytesRead,
4524+
context.length - context.totalBytesRead,
4525+
out bytesRead
4526+
);
4527+
context.totalBytesRead += bytesRead;
4528+
Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required");
4529+
4530+
if (!filledBuffer)
45214531
{
45224532
// This will be the 'state' for the callback
4523-
int totalBytesRead = bytesRead;
4524-
45254533
if (!isContinuation)
45264534
{
45274535
// This is the first async operation which is happening - setup the _currentTask and timeout
4536+
Debug.Assert(context._source==null, "context._source should not be non-null when trying to change to async");
45284537
source = new TaskCompletionSource<int>();
45294538
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
45304539
if (original != null)
45314540
{
45324541
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
45334542
return source.Task;
45344543
}
4535-
4544+
context._source = source;
45364545
// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
45374546
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
45384547
{
@@ -4561,7 +4570,7 @@ private Task<int> GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, b
45614570
}
45624571
else
45634572
{
4564-
Debug.Assert(context._source != null, "context.source should not be null when continuing");
4573+
Debug.Assert(context._source != null, "context._source should not be null when continuing");
45654574
// setup for cleanup/completing
45664575
retryTask.ContinueWith(
45674576
continuationAction: AAsyncCallContext<int>.s_completeCallback,

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs

+92-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,97 @@ public static void RunAllTestsForSingleServer_TCP()
3737
RunAllTestsForSingleServer(DataTestUtility.TCPConnectionString);
3838
}
3939

40+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
41+
public static async Task AsyncMultiPacketStreamRead()
42+
{
43+
int packetSize = 514; // force small packet size so we can quickly check multi packet reads
44+
45+
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString);
46+
builder.PacketSize = 514;
47+
string connectionString = builder.ToString();
48+
49+
byte[] inputData = null;
50+
byte[] outputData = null;
51+
string tableName = DataTestUtility.GetUniqueNameForSqlServer("data");
52+
53+
using (SqlConnection connection = new SqlConnection(connectionString))
54+
{
55+
await connection.OpenAsync();
56+
57+
try
58+
{
59+
inputData = CreateBinaryTable(connection, tableName, packetSize);
60+
61+
using (SqlCommand command = new SqlCommand($"SELECT foo FROM {tableName}", connection))
62+
using (SqlDataReader reader = await command.ExecuteReaderAsync(System.Data.CommandBehavior.SequentialAccess))
63+
{
64+
await reader.ReadAsync();
65+
66+
using (Stream stream = reader.GetStream(0))
67+
using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(60)))
68+
using (MemoryStream memory = new MemoryStream(16 * 1024))
69+
{
70+
await stream.CopyToAsync(memory, 37, cancellationTokenSource.Token); // prime number sized buffer to cause many cross packet partial reads
71+
outputData = memory.ToArray();
72+
}
73+
}
74+
}
75+
finally
76+
{
77+
DataTestUtility.DropTable(connection, tableName);
78+
}
79+
}
80+
81+
Assert.NotNull(outputData);
82+
int sharedLength = Math.Min(inputData.Length, outputData.Length);
83+
if (sharedLength < outputData.Length)
84+
{
85+
Assert.False(true, $"output is longer than input, input={inputData.Length} bytes, output={outputData.Length} bytes");
86+
}
87+
if (sharedLength < inputData.Length)
88+
{
89+
Assert.False(true, $"input is longer than output, input={inputData.Length} bytes, output={outputData.Length} bytes");
90+
}
91+
for (int index = 0; index < sharedLength; index++)
92+
{
93+
if (inputData[index] != outputData[index]) // avoid formatting the output string unless there is a difference
94+
{
95+
Assert.True(false, $"input and output differ at index {index}, input={inputData[index]}, output={outputData[index]}");
96+
}
97+
}
98+
99+
}
100+
101+
private static byte[] CreateBinaryTable(SqlConnection connection, string tableName, int packetSize)
102+
{
103+
byte[] pattern = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 };
104+
byte[] data = new byte[packetSize * 10];
105+
int position = 0;
106+
while (position < data.Length)
107+
{
108+
int copyCount = Math.Min(pattern.Length, data.Length - position);
109+
Array.Copy(pattern, 0, data, position, copyCount);
110+
position += copyCount;
111+
}
112+
113+
using (var cmd = connection.CreateCommand())
114+
{
115+
cmd.CommandText = $@"
116+
IF OBJECT_ID('dbo.{tableName}', 'U') IS NOT NULL
117+
DROP TABLE {tableName};
118+
CREATE TABLE {tableName} (id INT, foo VARBINARY(MAX))
119+
";
120+
cmd.ExecuteNonQuery();
121+
122+
cmd.CommandText = $"INSERT INTO {tableName} (id, foo) VALUES (@id, @foo)";
123+
cmd.Parameters.AddWithValue("id", 1);
124+
cmd.Parameters.AddWithValue("foo", data);
125+
cmd.ExecuteNonQuery();
126+
}
127+
128+
return data;
129+
}
130+
40131
private static void RunAllTestsForSingleServer(string connectionString, bool usingNamePipes = false)
41132
{
42133
RowBuffer(connectionString);
@@ -1811,7 +1902,7 @@ private static void TestXEventsStreaming(string connectionString)
18111902
SqlDataReader reader = cmd.ExecuteReader(System.Data.CommandBehavior.SequentialAccess);
18121903
for (int i = 0; i < streamXeventCount && reader.Read(); i++)
18131904
{
1814-
Int32 colType = reader.GetInt32(0);
1905+
int colType = reader.GetInt32(0);
18151906
int cb = (int)reader.GetBytes(1, 0, null, 0, 0);
18161907

18171908
byte[] bytes = new byte[cb];

0 commit comments

Comments
 (0)