@@ -14,6 +14,8 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
14
14
{
15
15
public static class InstanceNameTest
16
16
{
17
+ private const char SemicolonSeparator = ';' ;
18
+
17
19
[ ConditionalFact ( typeof ( DataTestUtility ) , nameof ( DataTestUtility . IsNotAzureServer ) , nameof ( DataTestUtility . IsNotAzureSynapse ) , nameof ( DataTestUtility . AreConnStringsSetup ) ) ]
18
20
public static void ConnectToSQLWithInstanceNameTest ( )
19
21
{
@@ -84,138 +86,135 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
84
86
}
85
87
}
86
88
87
- // Note: This Unit test was tested in a domain-joined VM connecting to a remote
88
- // SQL Server using Kerberos in the same domain.
89
- [ ActiveIssue ( "27824" ) ] // When specifying instance name and port number, this method call always returns false
90
- [ ConditionalFact ( nameof ( IsKerberos ) ) ]
91
- public static void PortNumberInSPNTest ( )
89
+ #if NETCOREAPP
90
+ [ ConditionalFact ( nameof ( IsSPNPortNumberTestForTCP ) ) ]
91
+ public static void PortNumberInSPNTestForTCP ( )
92
+ {
93
+ string connectionString = DataTestUtility . TCPConnectionString ;
94
+ SqlConnectionStringBuilder builder = new ( connectionString ) ;
95
+
96
+ int port = GetNamedInstancePortNumberFromSqlBrowser ( connectionString ) ;
97
+ Assert . True ( port > 0 , "Named instance must have a valid port number." ) ;
98
+ builder . DataSource = $ "{ builder . DataSource } ,{ port } ";
99
+
100
+ PortNumberInSPNTest ( builder . ConnectionString , port ) ;
101
+ }
102
+ #endif
103
+
104
+ private static void PortNumberInSPNTest ( string connectionString , int expectedPortNumber )
92
105
{
93
- string connStr = DataTestUtility . TCPConnectionString ;
94
- // If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true
95
106
if ( DataTestUtility . IsIntegratedSecuritySetup ( ) )
96
107
{
97
108
string [ ] removeKeys = { "Authentication" , "User ID" , "Password" , "UID" , "PWD" , "Trusted_Connection" } ;
98
- connStr = DataTestUtility . RemoveKeysInConnStr ( DataTestUtility . TCPConnectionString , removeKeys ) + $ "Integrated Security=true";
109
+ connectionString = DataTestUtility . RemoveKeysInConnStr ( connectionString , removeKeys ) + $ "Integrated Security=true";
99
110
}
100
111
101
- SqlConnectionStringBuilder builder = new ( connStr ) ;
112
+ SqlConnectionStringBuilder builder = new ( connectionString ) ;
113
+
114
+ string hostname = "" ;
115
+ string instanceName = "" ;
102
116
103
- Assert . True ( DataTestUtility . ParseDataSource ( builder . DataSource , out string hostname , out _ , out string instanceName ) , "Data source to be parsed must contain a host name and instance name" ) ;
117
+ DataTestUtility . ParseDataSource ( builder . DataSource , out hostname , out _ , out instanceName ) ;
104
118
105
- bool condition = IsBrowserAlive ( hostname ) && IsValidInstance ( hostname , instanceName ) ;
106
- Assert . True ( condition , "Browser service is not running or instance name is invalid " ) ;
119
+ Assert . False ( string . IsNullOrEmpty ( hostname ) , "Hostname must be included in the data source." ) ;
120
+ Assert . False ( string . IsNullOrEmpty ( instanceName ) , "Instance name must be included in the data source. " ) ;
107
121
108
- if ( condition )
122
+ using ( SqlConnection connection = new ( builder . ConnectionString ) )
109
123
{
110
- using SqlConnection connection = new ( builder . ConnectionString ) ;
111
124
connection . Open ( ) ;
112
- using SqlCommand command = new ( "SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid" , connection ) ;
113
- using SqlDataReader reader = command . ExecuteReader ( ) ;
114
- Assert . True ( reader . Read ( ) , "Expected to receive one row data" ) ;
115
- Assert . Equal ( "KERBEROS" , reader . GetString ( 0 ) ) ;
116
- int localTcpPort = reader . GetInt32 ( 1 ) ;
117
-
118
- int spnPort = - 1 ;
119
- string spnInfo = GetSPNInfo ( builder . DataSource , out spnPort ) ;
120
-
121
- // sample output to validate = MSSQLSvc/machine.domain.tld:spnPort"
122
- Assert . Contains ( $ "MSSQLSvc/{ hostname } ", spnInfo ) ;
123
- // the local_tcp_port should be the same as the inferred SPN port from instance name
124
- Assert . Equal ( localTcpPort , spnPort ) ;
125
+
126
+ string spnInfo = GetSPNInfo ( builder . DataSource ) ;
127
+ Assert . Matches ( @"MSSQLSvc\/.*:[\d]" , spnInfo ) ;
128
+
129
+ string [ ] spnStrs = spnInfo . Split ( ':' ) ;
130
+ int portInSPN = 0 ;
131
+ if ( spnStrs . Length > 1 )
132
+ {
133
+ int . TryParse ( spnStrs [ 1 ] , out portInSPN ) ;
134
+ }
135
+ Assert . Equal ( expectedPortNumber , portInSPN ) ;
125
136
}
126
137
}
127
138
128
- private static string GetSPNInfo ( string datasource , out int out_port )
139
+ private static string GetSPNInfo ( string dataSource )
129
140
{
130
141
Assembly sqlConnectionAssembly = Assembly . GetAssembly ( typeof ( SqlConnection ) ) ;
131
142
132
- // Get all required types using reflection
133
143
Type sniProxyType = sqlConnectionAssembly . GetType ( "Microsoft.Data.SqlClient.SNI.SNIProxy" ) ;
134
144
Type ssrpType = sqlConnectionAssembly . GetType ( "Microsoft.Data.SqlClient.SNI.SSRP" ) ;
135
145
Type dataSourceType = sqlConnectionAssembly . GetType ( "Microsoft.Data.SqlClient.SNI.DataSource" ) ;
136
146
Type timeoutTimerType = sqlConnectionAssembly . GetType ( "Microsoft.Data.ProviderBase.TimeoutTimer" ) ;
137
147
138
- // Used in Datasource constructor param type array
139
148
Type [ ] dataSourceConstructorTypesArray = new Type [ ] { typeof ( string ) } ;
140
149
141
- // Used in GetSqlServerSPNs function param types array
142
150
Type [ ] getSqlServerSPNsTypesArray = new Type [ ] { dataSourceType , typeof ( string ) } ;
143
151
144
- // GetPortByInstanceName parameters array
145
152
Type [ ] getPortByInstanceNameTypesArray = new Type [ ] { typeof ( string ) , typeof ( string ) , timeoutTimerType , typeof ( bool ) , typeof ( Microsoft . Data . SqlClient . SqlConnectionIPAddressPreference ) } ;
146
153
147
- // TimeoutTimer.StartSecondsTimeout params
148
154
Type [ ] startSecondsTimeoutTypesArray = new Type [ ] { typeof ( int ) } ;
149
155
150
- // Get all types constructors
151
- ConstructorInfo sniProxyCtor = sniProxyType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , Type . EmptyTypes , null ) ;
152
- ConstructorInfo SSRPCtor = ssrpType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , Type . EmptyTypes , null ) ;
153
- ConstructorInfo dataSourceCtor = dataSourceType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , dataSourceConstructorTypesArray , null ) ;
154
- ConstructorInfo timeoutTimerCtor = timeoutTimerType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , Type . EmptyTypes , null ) ;
156
+ ConstructorInfo sniProxyConstructor = sniProxyType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , Type . EmptyTypes , null ) ;
157
+ ConstructorInfo SSRPConstructor = ssrpType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , Type . EmptyTypes , null ) ;
158
+ ConstructorInfo dataSourceConstructor = dataSourceType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , dataSourceConstructorTypesArray , null ) ;
159
+ ConstructorInfo timeoutTimerConstructor = timeoutTimerType . GetConstructor ( BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , Type . EmptyTypes , null ) ;
155
160
156
- // Instantiate SNIProxy
157
- object sniProxy = sniProxyCtor . Invoke ( new object [ ] { } ) ;
161
+ object sniProxyObj = sniProxyConstructor . Invoke ( new object [ ] { } ) ;
158
162
159
- // Instantiate datasource
160
- object dataSourceObj = dataSourceCtor . Invoke ( new object [ ] { datasource } ) ;
163
+ object dataSourceObj = dataSourceConstructor . Invoke ( new object [ ] { dataSource } ) ;
161
164
162
- // Instantiate SSRP
163
- object ssrp = SSRPCtor . Invoke ( new object [ ] { } ) ;
165
+ object ssrpObj = SSRPConstructor . Invoke ( new object [ ] { } ) ;
164
166
165
- // Instantiate TimeoutTimer
166
- object timeoutTimer = timeoutTimerCtor . Invoke ( new object [ ] { } ) ;
167
+ object timeoutTimerObj = timeoutTimerConstructor . Invoke ( new object [ ] { } ) ;
167
168
168
- // Get TimeoutTimer.StartSecondsTimeout Method
169
- MethodInfo startSecondsTimeout = timeoutTimer . GetType ( ) . GetMethod ( "StartSecondsTimeout" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , startSecondsTimeoutTypesArray , null ) ;
170
- // Create a timeoutTimer that expires in 30 seconds
171
- timeoutTimer = startSecondsTimeout . Invoke ( dataSourceObj , new object [ ] { 30 } ) ;
169
+ MethodInfo startSecondsTimeoutInfo = timeoutTimerObj . GetType ( ) . GetMethod ( "StartSecondsTimeout" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , startSecondsTimeoutTypesArray , null ) ;
172
170
173
- // Parse the datasource to separate the server name and instance name
174
- MethodInfo ParseServerName = dataSourceObj . GetType ( ) . GetMethod ( "ParseServerName" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , dataSourceConstructorTypesArray , null ) ;
175
- object dataSrcInfo = ParseServerName . Invoke ( dataSourceObj , new object [ ] { datasource } ) ;
171
+ timeoutTimerObj = startSecondsTimeoutInfo . Invoke ( dataSourceObj , new object [ ] { 30 } ) ;
176
172
177
- // Get the GetPortByInstanceName method of SSRP
178
- MethodInfo getPortByInstanceName = ssrp . GetType ( ) . GetMethod ( "GetPortByInstanceName" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , getPortByInstanceNameTypesArray , null ) ;
173
+ MethodInfo parseServerNameInfo = dataSourceObj . GetType ( ) . GetMethod ( "ParseServerName" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , dataSourceConstructorTypesArray , null ) ;
174
+ object dataSrcInfo = parseServerNameInfo . Invoke ( dataSourceObj , new object [ ] { dataSource } ) ;
175
+
176
+ MethodInfo getPortByInstanceNameInfo = ssrpObj . GetType ( ) . GetMethod ( "GetPortByInstanceName" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , getPortByInstanceNameTypesArray , null ) ;
179
177
180
- // Get the server name
181
178
PropertyInfo serverInfo = dataSrcInfo . GetType ( ) . GetProperty ( "ServerName" , BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic ) ;
182
179
string serverName = serverInfo . GetValue ( dataSrcInfo , null ) . ToString ( ) ;
183
180
184
- // Get the instance name
185
181
PropertyInfo instanceNameInfo = dataSrcInfo . GetType ( ) . GetProperty ( "InstanceName" , BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic ) ;
186
182
string instanceName = instanceNameInfo . GetValue ( dataSrcInfo , null ) . ToString ( ) ;
187
183
188
- // Get the port number using the GetPortByInstanceName method of SSRP
189
- object port = getPortByInstanceName . Invoke ( ssrp , parameters : new object [ ] { serverName , instanceName , timeoutTimer , false , 0 } ) ;
184
+ object port = getPortByInstanceNameInfo . Invoke ( ssrpObj , parameters : new object [ ] { serverName , instanceName , timeoutTimerObj , false , 0 } ) ;
190
185
191
- // Set the resolved port property of datasource
192
186
PropertyInfo resolvedPortInfo = dataSrcInfo . GetType ( ) . GetProperty ( "ResolvedPort" , BindingFlags . Instance | BindingFlags . Public | BindingFlags . NonPublic ) ;
193
187
resolvedPortInfo . SetValue ( dataSrcInfo , ( int ) port , null ) ;
194
188
195
- // Prepare the GetSqlServerSPNs method
196
189
string serverSPN = "" ;
197
- MethodInfo getSqlServerSPNs = sniProxy . GetType ( ) . GetMethod ( "GetSqlServerSPNs" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , getSqlServerSPNsTypesArray , null ) ;
190
+ MethodInfo getSqlServerSPNs = sniProxyObj . GetType ( ) . GetMethod ( "GetSqlServerSPNs" , BindingFlags . Static | BindingFlags . Public | BindingFlags . NonPublic , null , CallingConventions . Any , getSqlServerSPNsTypesArray , null ) ;
198
191
199
- // Finally call GetSqlServerSPNs
200
- byte [ ] [ ] result = ( byte [ ] [ ] ) getSqlServerSPNs . Invoke ( sniProxy , new object [ ] { dataSrcInfo , serverSPN } ) ;
192
+ byte [ ] [ ] result = ( byte [ ] [ ] ) getSqlServerSPNs . Invoke ( sniProxyObj , new object [ ] { dataSrcInfo , serverSPN } ) ;
201
193
202
- // Example result: MSSQLSvc/machine.domain.tld:port"
203
194
string spnInfo = Encoding . Unicode . GetString ( result [ 0 ] ) ;
204
195
205
- out_port = ( int ) port ;
206
-
207
196
return spnInfo ;
208
197
}
209
198
210
- private static bool IsKerberos ( )
199
+ private static bool IsSPNPortNumberTestForTCP ( )
211
200
{
212
- return ( DataTestUtility . AreConnStringsSetup ( )
213
- && DataTestUtility . IsNotLocalhost ( )
214
- && DataTestUtility . IsKerberosTest
215
- && DataTestUtility . IsNotAzureServer ( )
201
+ return ( IsInstanceNameValid ( DataTestUtility . TCPConnectionString )
202
+ && DataTestUtility . IsUsingManagedSNI ( )
203
+ && DataTestUtility . IsNotAzureServer ( )
216
204
&& DataTestUtility . IsNotAzureSynapse ( ) ) ;
217
205
}
218
206
207
+ private static bool IsInstanceNameValid ( string connectionString )
208
+ {
209
+ string instanceName = "" ;
210
+
211
+ SqlConnectionStringBuilder builder = new ( connectionString ) ;
212
+
213
+ bool isDataSourceValid = DataTestUtility . ParseDataSource ( builder . DataSource , out _ , out _ , out instanceName ) ;
214
+
215
+ return isDataSourceValid && ! string . IsNullOrWhiteSpace ( instanceName ) ;
216
+ }
217
+
219
218
private static bool IsBrowserAlive ( string browserHostname )
220
219
{
221
220
const byte ClntUcastEx = 0x03 ;
@@ -231,6 +230,43 @@ private static bool IsValidInstance(string browserHostName, string instanceName)
231
230
return response != null && response . Length > 0 ;
232
231
}
233
232
233
+ private static int GetNamedInstancePortNumberFromSqlBrowser ( string connectionString )
234
+ {
235
+ SqlConnectionStringBuilder builder = new ( connectionString ) ;
236
+
237
+ string hostname = "" ;
238
+ string instanceName = "" ;
239
+ int port = 0 ;
240
+
241
+ bool isDataSourceValid = DataTestUtility . ParseDataSource ( builder . DataSource , out hostname , out _ , out instanceName ) ;
242
+ Assert . True ( isDataSourceValid , "DataSource is invalid" ) ;
243
+
244
+ bool isBrowserRunning = IsBrowserAlive ( hostname ) ;
245
+ Assert . True ( isBrowserRunning , "Browser service is not running." ) ;
246
+
247
+ bool isInstanceExisting = IsValidInstance ( hostname , instanceName ) ;
248
+ Assert . True ( isInstanceExisting , "Instance name is invalid." ) ;
249
+
250
+ if ( isDataSourceValid && isBrowserRunning && isInstanceExisting )
251
+ {
252
+ byte [ ] request = CreateInstanceInfoRequest ( instanceName ) ;
253
+ byte [ ] response = QueryBrowser ( hostname , request ) ;
254
+
255
+ string serverMessage = Encoding . ASCII . GetString ( response , 3 , response . Length - 3 ) ;
256
+
257
+ string [ ] elements = serverMessage . Split ( SemicolonSeparator ) ;
258
+ int tcpIndex = Array . IndexOf ( elements , "tcp" ) ;
259
+ if ( tcpIndex < 0 || tcpIndex == elements . Length - 1 )
260
+ {
261
+ throw new SocketException ( ) ;
262
+ }
263
+
264
+ port = ( int ) ushort . Parse ( elements [ tcpIndex + 1 ] ) ;
265
+ }
266
+
267
+ return port ;
268
+ }
269
+
234
270
private static byte [ ] QueryBrowser ( string browserHostname , byte [ ] requestPacket )
235
271
{
236
272
const int DefaultBrowserPort = 1434 ;
0 commit comments