Skip to content

Commit 94c089e

Browse files
authored
Fix | Fix unit test for SPN to include port number with Managed SNI (#2281)
1 parent 5cd9514 commit 94c089e

File tree

3 files changed

+113
-79
lines changed

3 files changed

+113
-79
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,10 @@ private bool InferConnectionDetails()
653653

654654
Port = port;
655655
}
656-
// Instance Name Handling. Only if we found a '\' and we did not find a port in the Data Source
657-
else if (backSlashIndex > -1)
656+
// Instance Name Handling.
657+
if (backSlashIndex > -1)
658658
{
659-
// This means that there will not be any part separated by comma.
659+
// This means that there is a part separated by '\'
660660
InstanceName = tokensByCommaAndSlash[1].Trim();
661661

662662
if (string.IsNullOrWhiteSpace(InstanceName))

src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs

+2-4
Original file line numberDiff line numberDiff line change
@@ -1004,9 +1004,6 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i
10041004
port = -1;
10051005
instanceName = string.Empty;
10061006

1007-
if (dataSource.Contains(",") && dataSource.Contains("\\"))
1008-
return false;
1009-
10101007
if (dataSource.Contains(":"))
10111008
{
10121009
dataSource = dataSource.Substring(dataSource.IndexOf(":", StringComparison.Ordinal) + 1);
@@ -1018,7 +1015,8 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i
10181015
{
10191016
return false;
10201017
}
1021-
dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal) - 1);
1018+
// IndexOf is zero-based, no need to subtract one
1019+
dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal));
10221020
}
10231021

10241022
if (dataSource.Contains("\\"))

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs

+108-72
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
1414
{
1515
public static class InstanceNameTest
1616
{
17+
private const char SemicolonSeparator = ';';
18+
1719
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
1820
public static void ConnectToSQLWithInstanceNameTest()
1921
{
@@ -84,138 +86,135 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
8486
}
8587
}
8688

87-
// Note: This Unit test was tested in a domain-joined VM connecting to a remote
88-
// SQL Server using Kerberos in the same domain.
89-
[ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false
90-
[ConditionalFact(nameof(IsKerberos))]
91-
public static void PortNumberInSPNTest()
89+
#if NETCOREAPP
90+
[ConditionalFact(nameof(IsSPNPortNumberTestForTCP))]
91+
public static void PortNumberInSPNTestForTCP()
92+
{
93+
string connectionString = DataTestUtility.TCPConnectionString;
94+
SqlConnectionStringBuilder builder = new(connectionString);
95+
96+
int port = GetNamedInstancePortNumberFromSqlBrowser(connectionString);
97+
Assert.True(port > 0, "Named instance must have a valid port number.");
98+
builder.DataSource = $"{builder.DataSource},{port}";
99+
100+
PortNumberInSPNTest(builder.ConnectionString, port);
101+
}
102+
#endif
103+
104+
private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber)
92105
{
93-
string connStr = DataTestUtility.TCPConnectionString;
94-
// If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true
95106
if (DataTestUtility.IsIntegratedSecuritySetup())
96107
{
97108
string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" };
98-
connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true";
109+
connectionString = DataTestUtility.RemoveKeysInConnStr(connectionString, removeKeys) + $"Integrated Security=true";
99110
}
100111

101-
SqlConnectionStringBuilder builder = new(connStr);
112+
SqlConnectionStringBuilder builder = new(connectionString);
113+
114+
string hostname = "";
115+
string instanceName = "";
102116

103-
Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name");
117+
DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName);
104118

105-
bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName);
106-
Assert.True(condition, "Browser service is not running or instance name is invalid");
119+
Assert.False(string.IsNullOrEmpty(hostname), "Hostname must be included in the data source.");
120+
Assert.False(string.IsNullOrEmpty(instanceName), "Instance name must be included in the data source.");
107121

108-
if (condition)
122+
using (SqlConnection connection = new(builder.ConnectionString))
109123
{
110-
using SqlConnection connection = new(builder.ConnectionString);
111124
connection.Open();
112-
using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection);
113-
using SqlDataReader reader = command.ExecuteReader();
114-
Assert.True(reader.Read(), "Expected to receive one row data");
115-
Assert.Equal("KERBEROS", reader.GetString(0));
116-
int localTcpPort = reader.GetInt32(1);
117-
118-
int spnPort = -1;
119-
string spnInfo = GetSPNInfo(builder.DataSource, out spnPort);
120-
121-
// sample output to validate = MSSQLSvc/machine.domain.tld:spnPort"
122-
Assert.Contains($"MSSQLSvc/{hostname}", spnInfo);
123-
// the local_tcp_port should be the same as the inferred SPN port from instance name
124-
Assert.Equal(localTcpPort, spnPort);
125+
126+
string spnInfo = GetSPNInfo(builder.DataSource);
127+
Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo);
128+
129+
string[] spnStrs = spnInfo.Split(':');
130+
int portInSPN = 0;
131+
if (spnStrs.Length > 1)
132+
{
133+
int.TryParse(spnStrs[1], out portInSPN);
134+
}
135+
Assert.Equal(expectedPortNumber, portInSPN);
125136
}
126137
}
127138

128-
private static string GetSPNInfo(string datasource, out int out_port)
139+
private static string GetSPNInfo(string dataSource)
129140
{
130141
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));
131142

132-
// Get all required types using reflection
133143
Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy");
134144
Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP");
135145
Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource");
136146
Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer");
137147

138-
// Used in Datasource constructor param type array
139148
Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) };
140149

141-
// Used in GetSqlServerSPNs function param types array
142150
Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) };
143151

144-
// GetPortByInstanceName parameters array
145152
Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) };
146153

147-
// TimeoutTimer.StartSecondsTimeout params
148154
Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) };
149155

150-
// Get all types constructors
151-
ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
152-
ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
153-
ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
154-
ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
156+
ConstructorInfo sniProxyConstructor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
157+
ConstructorInfo SSRPConstructor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
158+
ConstructorInfo dataSourceConstructor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
159+
ConstructorInfo timeoutTimerConstructor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
155160

156-
// Instantiate SNIProxy
157-
object sniProxy = sniProxyCtor.Invoke(new object[] { });
161+
object sniProxyObj = sniProxyConstructor.Invoke(new object[] { });
158162

159-
// Instantiate datasource
160-
object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource });
163+
object dataSourceObj = dataSourceConstructor.Invoke(new object[] { dataSource });
161164

162-
// Instantiate SSRP
163-
object ssrp = SSRPCtor.Invoke(new object[] { });
165+
object ssrpObj = SSRPConstructor.Invoke(new object[] { });
164166

165-
// Instantiate TimeoutTimer
166-
object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { });
167+
object timeoutTimerObj = timeoutTimerConstructor.Invoke(new object[] { });
167168

168-
// Get TimeoutTimer.StartSecondsTimeout Method
169-
MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);
170-
// Create a timeoutTimer that expires in 30 seconds
171-
timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 });
169+
MethodInfo startSecondsTimeoutInfo = timeoutTimerObj.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);
172170

173-
// Parse the datasource to separate the server name and instance name
174-
MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
175-
object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource });
171+
timeoutTimerObj = startSecondsTimeoutInfo.Invoke(dataSourceObj, new object[] { 30 });
176172

177-
// Get the GetPortByInstanceName method of SSRP
178-
MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);
173+
MethodInfo parseServerNameInfo = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
174+
object dataSrcInfo = parseServerNameInfo.Invoke(dataSourceObj, new object[] { dataSource });
175+
176+
MethodInfo getPortByInstanceNameInfo = ssrpObj.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);
179177

180-
// Get the server name
181178
PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
182179
string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString();
183180

184-
// Get the instance name
185181
PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
186182
string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString();
187183

188-
// Get the port number using the GetPortByInstanceName method of SSRP
189-
object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 });
184+
object port = getPortByInstanceNameInfo.Invoke(ssrpObj, parameters: new object[] { serverName, instanceName, timeoutTimerObj, false, 0 });
190185

191-
// Set the resolved port property of datasource
192186
PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
193187
resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null);
194188

195-
// Prepare the GetSqlServerSPNs method
196189
string serverSPN = "";
197-
MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);
190+
MethodInfo getSqlServerSPNs = sniProxyObj.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);
198191

199-
// Finally call GetSqlServerSPNs
200-
byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN });
192+
byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN });
201193

202-
// Example result: MSSQLSvc/machine.domain.tld:port"
203194
string spnInfo = Encoding.Unicode.GetString(result[0]);
204195

205-
out_port = (int)port;
206-
207196
return spnInfo;
208197
}
209198

210-
private static bool IsKerberos()
199+
private static bool IsSPNPortNumberTestForTCP()
211200
{
212-
return (DataTestUtility.AreConnStringsSetup()
213-
&& DataTestUtility.IsNotLocalhost()
214-
&& DataTestUtility.IsKerberosTest
215-
&& DataTestUtility.IsNotAzureServer()
201+
return (IsInstanceNameValid(DataTestUtility.TCPConnectionString)
202+
&& DataTestUtility.IsUsingManagedSNI()
203+
&& DataTestUtility.IsNotAzureServer()
216204
&& DataTestUtility.IsNotAzureSynapse());
217205
}
218206

207+
private static bool IsInstanceNameValid(string connectionString)
208+
{
209+
string instanceName = "";
210+
211+
SqlConnectionStringBuilder builder = new(connectionString);
212+
213+
bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out _, out _, out instanceName);
214+
215+
return isDataSourceValid && !string.IsNullOrWhiteSpace(instanceName);
216+
}
217+
219218
private static bool IsBrowserAlive(string browserHostname)
220219
{
221220
const byte ClntUcastEx = 0x03;
@@ -231,6 +230,43 @@ private static bool IsValidInstance(string browserHostName, string instanceName)
231230
return response != null && response.Length > 0;
232231
}
233232

233+
private static int GetNamedInstancePortNumberFromSqlBrowser(string connectionString)
234+
{
235+
SqlConnectionStringBuilder builder = new(connectionString);
236+
237+
string hostname = "";
238+
string instanceName = "";
239+
int port = 0;
240+
241+
bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName);
242+
Assert.True(isDataSourceValid, "DataSource is invalid");
243+
244+
bool isBrowserRunning = IsBrowserAlive(hostname);
245+
Assert.True(isBrowserRunning, "Browser service is not running.");
246+
247+
bool isInstanceExisting = IsValidInstance(hostname, instanceName);
248+
Assert.True(isInstanceExisting, "Instance name is invalid.");
249+
250+
if (isDataSourceValid && isBrowserRunning && isInstanceExisting)
251+
{
252+
byte[] request = CreateInstanceInfoRequest(instanceName);
253+
byte[] response = QueryBrowser(hostname, request);
254+
255+
string serverMessage = Encoding.ASCII.GetString(response, 3, response.Length - 3);
256+
257+
string[] elements = serverMessage.Split(SemicolonSeparator);
258+
int tcpIndex = Array.IndexOf(elements, "tcp");
259+
if (tcpIndex < 0 || tcpIndex == elements.Length - 1)
260+
{
261+
throw new SocketException();
262+
}
263+
264+
port = (int)ushort.Parse(elements[tcpIndex + 1]);
265+
}
266+
267+
return port;
268+
}
269+
234270
private static byte[] QueryBrowser(string browserHostname, byte[] requestPacket)
235271
{
236272
const int DefaultBrowserPort = 1434;

0 commit comments

Comments
 (0)